Pool two vectors together


I have two arrays, shown below: A contains the values 1 to 4, each repeated 8 times, and B contains the values 1 to 8, each repeated 4 times. I want to shuffle B, subject to one condition on the resulting pairing: no single B value (say, the four 1s) may appear in front of more than 3 different A values. To illustrate:

A: 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4
B: 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8 8 8

Acceptable shuffle

A:          1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4
B_shuffled: 8 8 8 8 2 3 2 1 2 1 5 6 4 3 7 1 7 1 7 2 7 5 5 4 6 5 6 6 3 4 3 4

Unacceptable shuffle (because 8 in B appears in front of all four A values, 1, 2, 3, and 4, which is more than the allowed three)

A:          1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4
B_shuffled: 8 5 7 6 2 3 2 1 2 1 5 8 4 3 7 1 7 1 8 2 7 5 5 4 6 8 6 6 3 4 3 4

import numpy as np

# A: the values 1..4, each repeated 8 times; B: the values 1..8, each repeated 4 times
A = np.repeat(np.arange(1, 5), 8).tolist()
B = np.repeat(np.arange(1, 9), 4).tolist()

I used random.shuffle(B) but couldn't figure out how to add distribution logic to it.
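
The condition described above can be checked mechanically. The following is a minimal sketch, not part of the original post: the helper name `distinct_a_per_b` is hypothetical, and the two lists are copied from the acceptable and unacceptable examples above.

```python
from collections import defaultdict

def distinct_a_per_b(A, B):
    """Map each B value to the number of distinct A values it is paired with."""
    seen = defaultdict(set)
    for a, b in zip(A, B):
        seen[b].add(a)
    return {b: len(a_values) for b, a_values in seen.items()}

# A: the values 1..4, each repeated 8 times
A = [a for a in range(1, 5) for _ in range(8)]

# the two shuffles shown in the examples above
acceptable = [8, 8, 8, 8, 2, 3, 2, 1, 2, 1, 5, 6, 4, 3, 7, 1,
              7, 1, 7, 2, 7, 5, 5, 4, 6, 5, 6, 6, 3, 4, 3, 4]
unacceptable = [8, 5, 7, 6, 2, 3, 2, 1, 2, 1, 5, 8, 4, 3, 7, 1,
                7, 1, 8, 2, 7, 5, 5, 4, 6, 8, 6, 6, 3, 4, 3, 4]

# worst case over all B values in the acceptable shuffle
print(max(distinct_a_per_b(A, acceptable).values()))  # 3
# the value 8 touches every A group in the unacceptable shuffle
print(distinct_a_per_b(A, unacceptable)[8])            # 4
```

A shuffle is acceptable exactly when the largest count in this mapping is at most 3.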

Answer by Woodford:

You can't add any kind of distribution logic to np.random.shuffle, and even if you could, your constraint isn't really suited to it. We're not stuck, though. Instead of using a one-liner, we can procedurally build up a shuffled array that satisfies your requirements.

from collections import defaultdict
import numpy as np

# start with a validation function to ensure we're doing the right thing
def validate_shuffle(A, B, max_A_values=3):
    assoc = defaultdict(set)
    for a, b in zip(A, B):
        assoc[b].add(a)
        if len(assoc[b]) > max_A_values:
            return False
    return True

# return a shuffled copy of B, or None if the operation failed
def pool_vectors(A, B, max_A_values=3):
    # shuffle a list of all of the indices in our array; our goal is to
    # assign B[n] -> B[indices[n]]
    indices = np.arange(len(A))
    np.random.shuffle(indices)

    # C holds our output
    C = np.empty(len(B), dtype=int)

    # this container tracks which A values are associated with each B value
    assoc = defaultdict(set)

    # iterate through each B value
    for b_idx, b in enumerate(B):

        # iterate through each shuffled index value, starting where we left off last time
        for newpos_idx, newpos in enumerate(indices[b_idx:], start=b_idx):

            # check to see if assigning the B value to this index breaks our constraint
            if len(assoc[b] | {A[newpos]}) <= max_A_values:

                # this new index works but we might have had to go through multiple
                # indices that were rejected; swap the working index into the
                # front of the subarray of unused indices so it will be skipped
                # over on the next outer loop iteration
                indices[b_idx], indices[newpos_idx] = indices[newpos_idx], indices[b_idx]
                break
        else:
            # because we're building the list procedurally, there are random 
            # circumstances where we can "paint ourselves into a corner"; this
            # doesn't happen often but we'll fail by returning None for now
            return None

        # assign the B value to the output array at the chosen index
        C[newpos] = b
        assoc[b].add(A[newpos])
    return C

# test it out
A = np.array([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4])
B = np.array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8])

while True:
    C = pool_vectors(A, B)
    if C is not None:
        assert validate_shuffle(A, C)
        print(A)
        print(C)
        break

# --- output --- #

[1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4]
[5 1 4 4 4 1 2 4 8 2 3 1 8 5 7 1 5 7 6 6 6 5 7 3 3 3 8 7 8 6 2 2]
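
As an aside, a different technique also works for inputs of this size: plain rejection sampling, i.e. shuffle B without any constraint and retry until the check passes. This is only a sketch under stated assumptions. The name `rejection_shuffle` and the `max_tries` parameter are hypothetical, and the validation function is repeated here so the block runs on its own. For these particular arrays, a given B value ends up in all four A groups with probability 8^4 / C(32, 4), about 11%, so by a union bound over the eight values at least roughly 9% of plain shuffles pass and a bounded number of retries is enough in practice.

```python
import random
from collections import defaultdict

def validate_shuffle(A, B, max_A_values=3):
    # same check as in the answer: no B value may touch more than max_A_values A values
    assoc = defaultdict(set)
    for a, b in zip(A, B):
        assoc[b].add(a)
        if len(assoc[b]) > max_A_values:
            return False
    return True

def rejection_shuffle(A, B, max_A_values=3, max_tries=10_000, rng=random):
    # hypothetical helper: shuffle a copy of B until the constraint holds,
    # or return None after max_tries failed attempts
    candidate = list(B)
    for _ in range(max_tries):
        rng.shuffle(candidate)
        if validate_shuffle(A, candidate, max_A_values):
            return candidate
    return None

A = [a for a in range(1, 5) for _ in range(8)]
B = [b for b in range(1, 9) for _ in range(4)]

C = rejection_shuffle(A, B)
print(C)
```

Because every accepted result starts from a uniform shuffle, each valid arrangement is equally likely, which the greedy construction does not obviously guarantee. The trade-off is that the acceptance rate can drop sharply for tighter constraints or larger arrays.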

Note that nothing in pool_vectors is numpy-specific; unless you're using numpy in the surrounding code, you can easily replace the np.* functionality with vanilla Python.
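
A possible vanilla-Python port might look like the sketch below. It assumes the standard-library `random` module as a stand-in for `np.random`; the names `pool_vectors_plain` and `pool_with_retries` and the `max_tries` cap are hypothetical additions, not part of the answer's code.

```python
import random
from collections import defaultdict

def pool_vectors_plain(A, B, max_A_values=3, rng=random):
    # shuffled list of target positions; B[n] is placed at one of the unused ones
    indices = list(range(len(A)))
    rng.shuffle(indices)

    C = [None] * len(B)          # output list
    assoc = defaultdict(set)     # B value -> set of A values it is paired with

    for b_idx, b in enumerate(B):
        # scan the still-unused positions, which live in indices[b_idx:]
        for newpos_idx in range(b_idx, len(indices)):
            newpos = indices[newpos_idx]
            if len(assoc[b] | {A[newpos]}) <= max_A_values:
                # move the chosen position into the "used" prefix of indices
                indices[b_idx], indices[newpos_idx] = indices[newpos_idx], indices[b_idx]
                break
        else:
            # no remaining position satisfies the constraint for this B value
            return None
        C[newpos] = b
        assoc[b].add(A[newpos])
    return C

def pool_with_retries(A, B, max_A_values=3, max_tries=1000):
    # bounded replacement for the `while True` loop: give up after max_tries
    for _ in range(max_tries):
        C = pool_vectors_plain(A, B, max_A_values)
        if C is not None:
            return C
    return None

A = [a for a in range(1, 5) for _ in range(8)]
B = [b for b in range(1, 9) for _ in range(4)]
C = pool_with_retries(A, B)
print(C)
```

Capping the retries avoids an endless loop if the constraint turns out to be unsatisfiable for some other pair of inputs; the caller then has to handle a `None` result.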