I am implementing a noise gate in Python, using the state design pattern.
My implementation takes an array of audio samples and, using the parameters of the noise gate, the audio sample magnitude values, and the state of the noise gate, determines a coefficient value in the range [0, 1] which should be multiplied with the current audio sample value.
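Here is a minimal illustration of the coefficient idea. It deliberately ignores attack, hold, release and lookahead, so it is a plain hard gate, which is a simplifying assumption made only for this sketch. Each output sample is the input sample multiplied by its coefficient. The threshold conversion uses the same formula as the dBFS_to_lin helper further down. The function name hard_gate_coefficients is hypothetical.

```python
import numpy as np


def hard_gate_coefficients(samples: np.ndarray, thresh_dbfs: float) -> np.ndarray:
    """Hypothetical helper: 1.0 where |sample| exceeds the threshold, else 0.0 (no ramps)."""
    lin_thresh = 10 ** (thresh_dbfs / 20)  # e.g. -20 dBFS -> 0.1
    return (np.abs(samples) > lin_thresh).astype(float)


samples = np.array([0.05, 0.5, -0.3, 0.01])
coefs = hard_gate_coefficients(samples, thresh_dbfs=-20)
gated = coefs * samples  # element-wise: each sample times its coefficient
print(coefs)  # [0. 1. 1. 0.]
```

The state machine in the question replaces the instant 0/1 switch with ramps and a hold period. The final multiplication stays the same.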
The states I have defined are OpenState, ClosedState, OpeningState and ClosingState. I believe the image below contains all of the state transitions I need to consider.
When the gate is in ClosingState, there are two possible transitions:
ClosingState -> ClosedState: this occurs if the release period elapses without another peak exceeding the threshold during that time.
ClosingState -> OpenState: this occurs if a peak exceeds the threshold at some point during the release period.
The part of my code that decides which state to transition to is this method inside the ClosingState class:
def check_if_state_transition_is_due(self, sample_mag: float = 0) -> bool:
    '''
    There are two possible states that we can transition to from ClosingState.
    Feels strange to introduce conditionals to determine state transition(?)
    '''
    # This doesn't feel right introducing these conditionals here.
    if sample_mag > self.context.lin_thresh:
        self.transition_pending = True
        self.new_state = OpenState()
        return True
    if self.sample_counter >= self.context.release_period_in_samples - 1:
        self.transition_pending = True
        self.new_state = ClosedState()
        return True
    # No transition due on this sample
    return False
My question is simply whether it is OK to use these conditionals to determine which state to transition to. It feels like re-introducing the type of code that using the state pattern gets rid of, but an alternative is not obvious to me.
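One possible alternative, sketched here under my own assumptions, is to express the transitions as data rather than as an if chain. Each state lists (guard, target class) pairs in priority order, and a single base-class method picks the first guard that holds. The conditions do not disappear; they move into a table. The transitions attribute, the next_state method and the stripped-down classes below are hypothetical and are not part of the code in the question. A SimpleNamespace stands in for the real Context.

```python
from types import SimpleNamespace
from typing import Optional


class State:
    # (guard, target_class) pairs, checked in order; the first guard that
    # returns True wins. Hypothetical attribute name for this sketch.
    transitions: tuple = ()

    def __init__(self) -> None:
        self.sample_counter = 0
        self.context = None

    def next_state(self, sample_mag: float) -> Optional["State"]:
        """Return a new instance of the first target whose guard holds, else None."""
        for guard, target_cls in self.transitions:
            if guard(self, sample_mag):
                return target_cls()
        return None


class OpenState(State):
    pass


class ClosedState(State):
    pass


class ClosingState(State):
    transitions = (
        # A new peak during the release period reopens the gate.
        (lambda s, mag: mag > s.context.lin_thresh, OpenState),
        # Otherwise, close once the end of the release ramp is reached.
        (lambda s, mag: s.sample_counter >= s.context.release_period_in_samples - 1,
         ClosedState),
    )


ctx = SimpleNamespace(lin_thresh=0.1, release_period_in_samples=100)
closing = ClosingState()
closing.context = ctx
print(closing.next_state(0.05))                # None: stay in ClosingState
print(type(closing.next_state(0.5)).__name__)  # OpenState
```

Whether this is clearer than two plain if statements is a judgement call. The approach pays off mainly when many states share the same "first matching guard wins" logic.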
EDIT: I think I have clarified this for myself. Using the state pattern allows us to get rid of code that looks like this:
if state == "closed":
    ...  # do something
elif state == "open":
    ...  # do something
elif state == "closing":
    ...  # do something
elif state == "opening":
    ...  # do something
The conditionals whose validity I was questioning are not the same as this. I am checking some conditions based on data, rather than checking which state I am in.
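To make the distinction concrete, here is a stripped-down sketch of what replaces the if/elif chain. The class names Closed, Open and Gate are simplified stand-ins, not the classes from the question. The context delegates to whatever state object it currently holds, so nothing branches on the state's name.

```python
class Closed:
    def coefficient(self) -> float:
        return 0.0


class Open:
    def coefficient(self) -> float:
        return 1.0


class Gate:
    def __init__(self, state) -> None:
        self.state = state

    def coefficient(self) -> float:
        # No `if state == "open"` chain: the current state object answers directly.
        return self.state.coefficient()


gate = Gate(Closed())
print(gate.coefficient())  # 0.0
gate.state = Open()
print(gate.coefficient())  # 1.0
```

Any conditionals that remain inside a state class test data such as sample magnitudes and counters, not which state the gate is in.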
Below is a minimal example. It may not be needed for my conceptual question above, but I am including it in case it is. The sample audio file can be found here.
SO_ramp_functions.py
import numpy as np


def ramp_linear_increase(num_points):
    ''' Function defining a linear increase from 0 to 1 in num_points samples '''
    return np.linspace(0, 1, num_points)


def ramp_linear_decrease(num_points):
    ''' Function defining a linear decrease from 1 to 0 in num_points samples '''
    return np.linspace(1, 0, num_points)


def ramp_poly_increase(num_points):
    ''' Generate an array of coefficient values for the attack period '''
    x = np.arange(num_points, 0, -1)
    attack_coef_arr = 1 - (x / num_points) ** 4
    # Make sure the start and end are 0 and 1, respectively
    attack_coef_arr[0] = 0
    attack_coef_arr[-1] = 1
    return attack_coef_arr


def ramp_poly_decrease(num_points):
    ''' Generate an array of coefficient values for the release period '''
    x = np.arange(num_points)
    release_coef_arr = 1 - (x / num_points) ** 4
    # Make sure the start and end are 1 and 0, respectively
    release_coef_arr[0] = 1
    release_coef_arr[-1] = 0
    return release_coef_arr
SO_gate_states.py
from abc import ABC, abstractmethod


class State(ABC):
    """
    The base State class declares methods that all concrete States should
    implement and also provides a backreference to the Context object,
    associated with the State. This backreference can be used by States to
    transition the Context to another State.
    """

    @property
    def context(self):
        return self._context

    @context.setter
    def context(self, context) -> None:
        self._context = context

    @abstractmethod
    def get_sample_coefficient(self, sample_mag: float) -> float:
        pass

    @abstractmethod
    def check_if_state_transition_is_due(self, sample_mag: float = 0) -> bool:
        pass

    @abstractmethod
    def on_entry(self):
        pass

    @abstractmethod
    def on_exit(self):
        pass


"""
Concrete States implement various behaviors, associated with a state of the
Context.
"""
class ClosedState(State):
    def __init__(self):
        self.sample_counter = 0
        self.transition_pending = False

    def get_sample_coefficient(self, sample_mag: float = 0) -> float:
        '''
        Get the appropriate coefficient value to multiply with the current
        audio sample value.
        In the closed state, the coefficient is always 0.0.
        '''
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)
        return 0.0

    def check_if_state_transition_is_due(self, sample_mag: float = 0) -> bool:
        '''
        Check if a condition is met that initiates a transition.
        For ClosedState, we want to check if the sample magnitude exceeds the threshold.
        '''
        return sample_mag > self.context.lin_thresh

    def on_entry(self):
        pass

    def on_exit(self):
        pass

    def handle_state_transition(self):
        if self.transition_pending:
            self.context.transition_to(OpeningState())
class OpeningState(State):
    '''
    - In OpeningState, the coefficient is determined by the shape of the
      specified attack ramp.
    - The only state we can transition to from OpeningState is OpenState.
    '''
    def __init__(self):
        self.sample_counter = 0
        self.transition_pending = False

    def get_sample_coefficient(self, sample_mag: float = 0) -> float:
        self.transition_pending = self.check_if_state_transition_is_due()
        if self.transition_pending:
            return 1.0
        else:
            # Get a value from the gate's attack ramp
            return self.context.attack_ramp[self.sample_counter]

    def check_if_state_transition_is_due(self, sample_mag: float = 0) -> bool:
        # Transition to OpenState occurs once attack period has elapsed
        return self.sample_counter >= self.context.attack_period_in_samples

    def handle_state_transition(self):
        if self.transition_pending:
            self.context.transition_to(OpenState())
            self.on_exit()

    def on_entry(self):
        pass

    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0
class OpenState(State):
    '''
    In OpenState, the coefficient is always 1.0.
    The only state we can transition to from OpenState is ClosingState.
    '''
    def __init__(self):
        self.sample_counter = 0
        self.transition_pending = False

    def get_sample_coefficient(self, sample_mag: float = 0) -> float:
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)
        return 1.0

    def check_if_state_transition_is_due(self, sample_mag: float = 0) -> bool:
        # The gate can't transition before its hold period has elapsed
        if self.sample_counter < self.context.hold_period_in_samples:
            return False
        else:
            # If the signal magnitude falls below the threshold, we want to
            # transition to ClosingState.
            return sample_mag < self.context.lin_thresh

    def on_entry(self):
        pass

    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0

    def handle_state_transition(self):
        if self.transition_pending:
            self.context.transition_to(ClosingState())
            self.on_exit()
class ClosingState(State):
    '''
    - The coefficient is determined by the shape of the specified release ramp.
    - The state can transition to either ClosedState or OpenState.
    '''
    def __init__(self):
        self.sample_counter = 0
        self.transition_pending = False
        self.new_state = None

    def get_sample_coefficient(self, sample_mag: float = 0) -> float:
        self.transition_pending = self.check_if_state_transition_is_due(sample_mag)
        return self.context.release_ramp[self.sample_counter]

    def check_if_state_transition_is_due(self, sample_mag: float = 0) -> bool:
        '''
        There are two possible states that we can transition to from ClosingState.
        Feels strange to introduce conditionals to determine state transition(?)
        '''
        # This doesn't feel right introducing these conditionals here.
        if sample_mag > self.context.lin_thresh:
            self.transition_pending = True
            self.new_state = OpenState()
            return True
        if self.sample_counter >= self.context.release_period_in_samples - 1:
            self.transition_pending = True
            self.new_state = ClosedState()
            return True
        # No transition due on this sample
        return False

    def handle_state_transition(self):
        if self.transition_pending:
            self.context.transition_to(self.new_state)
            self.on_exit()

    def on_entry(self):
        pass

    def on_exit(self):
        # This may not be needed, since we construct a new instance when
        # transitioning, but it may make it more robust
        self.sample_counter = 0
SO_noise_gate_state_pattern.py
import numpy as np
import SO_ramp_functions as rf

'''
The original template code is found here:
https://refactoring.guru/design-patterns/state/python/example
'''


class AudioConfig:
    '''
    Values that configure audio playback, so they can be set independently
    of, and shared between, different objects that need them.
    '''
    def __init__(self, fs):
        self.fs = fs


class Context:
    """
    This class represents the noise gate.
    The Context defines the interface of interest to clients. It also maintains
    a reference to an instance of a State subclass, which represents the current
    state of the Context.
    """

    def __init__(self, audio_config, state) -> None:
        self.audio_config = audio_config
        self.transition_to(state)
        # Specify an initial threshold value in dBFS
        self.thresh = -20
        # Specify attack, hold, release, and lookahead periods in seconds
        self.attack_time = 0.005  # seconds
        self.hold_time = 0.05  # seconds
        self.release_time = 0.1  # seconds
        self.lookahead_time = 0.005  # seconds
        # Calculate attack, hold, release, and lookahead periods in samples
        self.attack_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.attack_time)
        self.hold_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.hold_time)
        self.release_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.release_time)
        self.lookahead_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.lookahead_time)
        # Define the attack and release multiplier ramps - use strategy pattern?
        self.attack_ramp = rf.ramp_poly_increase(num_points=self.attack_period_in_samples)
        self.release_ramp = rf.ramp_poly_decrease(num_points=self.release_period_in_samples)
        # Initialise attributes to store the processed result
        self.processed_array = None
        self.coef_array = None
        # Padding to enable lookahead (a bit of a hack)
        self.lookahead_pad_samples = self.lookahead_period_in_samples
        # Attributes for debugging
        self.text_output = []
    def transition_to(self, state):
        """
        The Context allows changing the State object at runtime.
        """
        # print(f"Context: Transition to {type(state).__name__}")
        self._state = state
        self._state.context = self
    # Setters for gate parameters. The ramps are rebuilt so that their lengths
    # always match the periods that the states use to index into them.
    def set_attack_time(self, new_attack_time: float) -> None:
        self.attack_time = new_attack_time
        self.attack_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.attack_time)
        self.attack_ramp = rf.ramp_poly_increase(num_points=self.attack_period_in_samples)

    def set_hold_time(self, new_hold_time: float) -> None:
        self.hold_time = new_hold_time
        self.hold_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.hold_time)

    def set_release_time(self, new_release_time: float) -> None:
        self.release_time = new_release_time
        self.release_period_in_samples = self.seconds_to_samples(self.audio_config.fs, self.release_time)
        self.release_ramp = rf.ramp_poly_decrease(num_points=self.release_period_in_samples)
    @property
    def thresh(self) -> int:
        return self._thresh

    @thresh.setter
    def thresh(self, new_thresh: int) -> None:
        self._thresh = new_thresh

    @property
    def lin_thresh(self) -> float:
        return self.dBFS_to_lin(self.thresh)

    # These staticmethods could equally be defined outside the class
    @staticmethod
    def dBFS_to_lin(dBFS_val):
        ''' Helper method to convert a dBFS value to a linear value [0, 1] '''
        return 10 ** (dBFS_val / 20)

    @staticmethod
    def seconds_to_samples(fs, seconds_val):
        ''' Helper method to convert a time (seconds) value to a number of samples '''
        return int(fs * seconds_val)
    def process_audio_block(self, audio_array=None):
        '''
        Process an array of audio samples according to the gate's parameters,
        current state, and the sample values in the audio array.
        This implementation includes lookahead logic: audio_array is expected
        to end with lookahead_pad_samples extra samples of padding.
        '''
        # Number of output samples (everything except the lookahead padding).
        # Slicing with [:-self.lookahead_pad_samples] would return an empty
        # array when the padding is 0, so compute the length explicitly.
        num_out = len(audio_array) - self.lookahead_pad_samples
        # Initialise coefficient values outside the valid range [0, 1] for easier debugging
        self.coef_array = np.full(num_out, 2.0)
        # Get the magnitude values of the audio array
        self.mag_array = np.abs(audio_array)
        # Iterate through the output samples, updating coef_array values
        for i in range(num_out):
            # Get the coefficient value for the current sample, considering a lookahead period
            self.coef_array[i] = self._state.get_sample_coefficient(self.mag_array[i + self.lookahead_period_in_samples])
            # Increment the counter for tracking the samples elapsed in the current state
            self._state.sample_counter += 1
            # Create a log of the state and samples elapsed, for debugging
            self.text_output.append(f"{type(self._state).__name__}. {self._state.sample_counter}. {self.coef_array[i]:.3f}")
            # After processing the current sample, check if a transition is due
            self._state.handle_state_transition()
        self.processed_array = self.coef_array * audio_array[:num_out]
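On the "use strategy pattern?" comment in Context.__init__: one way to apply the strategy pattern is to pass the ramp-generating functions in as parameters instead of hard-coding rf.ramp_poly_increase and rf.ramp_poly_decrease. The GateRamps class below is a hypothetical helper written for this sketch. Any function that maps a length in samples to an array of that length works as a strategy, so rf.ramp_linear_increase or rf.ramp_poly_decrease would fit. The sketch also rebuilds a ramp whenever its period changes. That keeps the ramp length equal to the period in samples, which the state classes rely on when they index into it.

```python
from typing import Callable

import numpy as np

# A ramp "strategy" maps a length in samples to that many coefficients.
RampFn = Callable[[int], np.ndarray]


class GateRamps:
    """Hypothetical helper: holds the ramp strategies and rebuilds the ramps
    whenever a period (in samples) changes."""

    def __init__(self, attack_fn: RampFn, release_fn: RampFn,
                 attack_len: int, release_len: int) -> None:
        self.attack_fn = attack_fn
        self.release_fn = release_fn
        self.set_attack_len(attack_len)
        self.set_release_len(release_len)

    def set_attack_len(self, num_points: int) -> None:
        self.attack_ramp = self.attack_fn(num_points)

    def set_release_len(self, num_points: int) -> None:
        self.release_ramp = self.release_fn(num_points)


# Swapping ramp shapes means passing different functions, not editing the class.
ramps = GateRamps(attack_fn=lambda n: np.linspace(0, 1, n),
                  release_fn=lambda n: np.linspace(1, 0, n),
                  attack_len=5, release_len=3)
ramps.set_release_len(4)
print(len(ramps.release_ramp))  # 4
```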
main.py
'''
Driver code for the noise gate using the state pattern.
'''
from SO_noise_gate_state_pattern import AudioConfig, Context
from SO_gate_states import ClosedState
import numpy as np
import audiofile
import matplotlib.pyplot as plt
import time


# Define some helper/test functions
def load_audio(fpath):
    # fs is not returned; the driver below assumes the file is at 44100 Hz
    data, fs = audiofile.read(fpath)
    data = data.T
    if len(data.shape) == 2:
        data = data[:, 0]  # keep only the first channel (mono)
    return data


def test_gate_coef_values_are_valid(coef_arr):
    print("Testing gate coef_array values")
    assert np.all((coef_arr >= 0) & (coef_arr <= 1))
if __name__ == "__main__":
    # The client code.
    # Configure some audio properties
    audio_config = AudioConfig(fs=44100)
    # Create a "context" instance (this is like the NoiseGate class)
    context = Context(audio_config, ClosedState())
    # Load audio from file
    sig = load_audio(fpath="./snare_test.wav")
    # Zero-pad the audio array to enable lookahead (experimental)
    sig = np.concatenate((sig, np.zeros(context.lookahead_pad_samples)))
    # Process the whole array and time it
    start_time = time.perf_counter()
    context.process_audio_block(sig)
    end_time = time.perf_counter()
    print(f"Time taken to process {len(sig)/audio_config.fs:.2f} seconds of audio: {end_time - start_time:.2f} seconds")
    # Some testing on the result
    test_gate_coef_values_are_valid(context.coef_array)
    # Plot the result
    plt.plot(context.mag_array, color='blue', linewidth=1, label='signal magnitude')
    plt.plot(context.coef_array, color='green', label='gate coefficient')
    plt.plot(np.abs(context.processed_array), color='orange', label='gate output')
    plt.axhline(context.lin_thresh, color='black', linewidth=1, label='gate threshold')
    plt.legend()
    plt.show()
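If snare_test.wav is not at hand, a synthetic signal can stand in for it. The make_test_burst function below is hypothetical, and its parameter values are arbitrary choices for this sketch. It produces a low noise floor (below the default -20 dBFS threshold, i.e. 0.1 linear) with exponentially decaying noise bursts that exceed the threshold. In main.py it could replace the load_audio call, e.g. sig = make_test_burst(fs=audio_config.fs).

```python
import numpy as np


def make_test_burst(fs: int = 44100, duration: float = 1.0,
                    hit_times=(0.1, 0.5), decay: float = 0.05,
                    noise_level: float = 0.01, peak: float = 0.8,
                    seed: int = 0) -> np.ndarray:
    """Hypothetical stand-in for snare_test.wav: decaying noise bursts over a noise floor."""
    rng = np.random.default_rng(seed)
    n = int(fs * duration)
    t = np.arange(n) / fs
    # Constant low-level noise floor that the gate should remove
    sig = noise_level * rng.uniform(-1.0, 1.0, n)
    for t0 in hit_times:
        # Exponential decay envelope starting at t0 (zero before t0);
        # clipping avoids computing large exponentials before the hit.
        dt = np.clip(t - t0, 0.0, None)
        env = np.exp(-dt / decay) * (t >= t0)
        sig += peak * env * rng.uniform(-1.0, 1.0, n)
    return sig


sig = make_test_burst()
print(sig.shape)  # (44100,)
```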
It's OK to put logic that selects the correct next state inside a state. But if managing that code becomes a concern, you can use other patterns to manage the complexity; I think the factory method pattern or the chain of responsibility pattern could be useful. However, overusing design patterns can make your code more complex. I would rather wrap the code that is likely to vary in functions with meaningful names, so that the state transition method stays clear.
So if your conditions are fragile (they need to be flexible enough to accept future changes), make functions to represent them. Likewise, if your transition logic itself needs to be flexible, put that logic in its own functions.
Remember that the goal of the state pattern is to separate the states so that each can be managed independently, which reduces side effects between them. It is not the intention of the state pattern to reduce the complexity within each state; to do that, you should consider applying other patterns.
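As a sketch of the "wrap the conditions in well-named functions" suggestion, the ClosingState check could read like this. The predicate names are hypothetical. The class is a trimmed stand-in that records the target state's name as a string, rather than constructing OpenState or ClosedState, so that the example runs on its own. A SimpleNamespace plays the role of the Context.

```python
from types import SimpleNamespace


def peak_exceeds_threshold(sample_mag: float, lin_thresh: float) -> bool:
    return sample_mag > lin_thresh


def release_period_elapsed(sample_counter: int, release_len: int) -> bool:
    # The release ramp is indexed 0..release_len-1, so its last index marks the end.
    return sample_counter >= release_len - 1


class ClosingState:
    """Trimmed stand-in: only the transition check is shown."""

    def __init__(self, context) -> None:
        self.context = context
        self.sample_counter = 0
        self.new_state_name = None

    def check_if_state_transition_is_due(self, sample_mag: float = 0.0) -> bool:
        ctx = self.context
        if peak_exceeds_threshold(sample_mag, ctx.lin_thresh):
            self.new_state_name = "OpenState"
        elif release_period_elapsed(self.sample_counter, ctx.release_period_in_samples):
            self.new_state_name = "ClosedState"
        else:
            self.new_state_name = None
        return self.new_state_name is not None


state = ClosingState(SimpleNamespace(lin_thresh=0.1, release_period_in_samples=100))
print(state.check_if_state_transition_is_due(0.5), state.new_state_name)  # True OpenState
```

The conditionals are still there, but each one now has a name that says what it checks. Changing a condition later means editing one small function.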