How can I DRY up this pytest mock?

41 views

I have some test methods that require a mock of a class. The code below works, but I am repeating the same mock in several test methods. Here is the repeated code; you can see it used in the last three methods below:

# Repeated verbatim inside each test method below; `test_class_self` is the
# enclosing test method's `self`, captured by the nested class as a closure
# variable (it is not defined at module level).
class PyWaves(object):
    """Fake wave list whose names() returns the test class's mock wave names."""
    def names(self):
        # `self` here is the PyWaves instance, hence the outer alias.
        return test_class_self.mock_wave_names()

I tried moving it into the test class's `__init__`, but that does not work. What is the right way to avoid repeating this code in every method that needs the mock?

class TestExtractedXPSNameAdapter:
    """Tests for ``ExtractedXPSNameAdapter``.

    The fake wave list is built once by ``_fake_waves`` rather than being
    re-declared inside every test. Shared setup is not placed in
    ``__init__`` because pytest refuses to collect test classes that define
    a constructor; a helper method (or a fixture) is the supported way to
    share setup between test methods.

    Some tests still request the ``mocker`` fixture, which they no longer
    use; the parameter is kept so the test signatures are unchanged.
    """

    class _FakeWaves:
        """Minimal stand-in for a PyWaves object: exposes only ``names()``.

        The leading underscore keeps pytest from treating it as a test class.
        """

        def __init__(self, names_fn):
            # Stored, not called: names() invokes it on every call.
            self._names_fn = names_fn

        def names(self):
            """Return the result of calling the injected function."""
            return self._names_fn()

    def _fake_waves(self):
        """Return a fake wave list whose ``names()`` yields ``mock_wave_names()``."""
        return self._FakeWaves(self.mock_wave_names)

    def mock_wave_names(self):
        """Return the raw waveform names served by the fake wave list."""
        return [
            'net_a',
            'top.net_b',
            'top.foo@bar@net1#foo@bar@inst@0_g',
            'top.foo1@bar1@net2#foo1@bar1@inst1_d',
        ]

    def expected_index(self):
        """Return the index expected from ``build_index()`` over ``mock_wave_names()``."""
        return {
            'net_a': 'net_a',
            'top.net_b': 'top.net_b',
            'foo@bar@inst@0_g': 'top.foo@bar@net1#foo@bar@inst@0_g',
            'foo1@bar1@inst1_d': 'top.foo1@bar1@net2#foo1@bar1@inst1_d',
        }

    def test_gen_populated_name_guesses(self):
        """With a terminal name, guesses are produced with and without ``@0``."""
        adapter = ExtractedXPSNameAdapter()
        identifier = WaveformIdentifier(instance_name="my_inst", path="/i_macro/sub1/sub2", term_name="my_terminal")
        guesses = adapter.gen_name_guesses(identifier)
        expected_guesses = [
            "i_macro@sub1@sub2@my_inst@0_my_terminal",
            "i_macro@sub1@sub2@my_inst_my_terminal",
        ]
        assert guesses == expected_guesses

    def test_gen_empty_name_guesses(self):
        """Without a terminal name, no guesses are produced."""
        adapter = ExtractedXPSNameAdapter()
        identifier = WaveformIdentifier(instance_name="my_inst", path="/i_macro/sub1/sub2")
        guesses = adapter.gen_name_guesses(identifier)
        expected_guesses = []
        assert guesses == expected_guesses

    def test_build_index(self, mocker):
        """Indexing the fake wave names yields ``expected_index()``."""
        adapter = ExtractedXPSNameAdapter()
        adapter.set_wave_list(self._fake_waves())
        adapter.build_index()
        assert adapter.index == self.expected_index()

    def test_waveform_name(self, mocker):
        """An identifier resolves to its full extracted waveform name."""
        adapter = ExtractedXPSNameAdapter()
        adapter.set_wave_list(self._fake_waves())
        identifier = WaveformIdentifier(instance_name='inst1', path='/foo1/bar1', term_name="d")
        name = adapter.waveform_name(identifier)
        assert name == 'top.foo1@bar1@net2#foo1@bar1@inst1_d'

    def test_info_from_identifier(self, mocker):
        """``info_from_identifier`` combines identifier fields with the resolved name."""
        adapter = ExtractedXPSNameAdapter()
        adapter.set_wave_list(self._fake_waves())
        identifier = WaveformIdentifier(instance_name='inst1', path='/foo1/bar1', term_name="d")
        expected = WaveformNameInfo(
            identifier=identifier,
            type=identifier.type,
            full_name='top.foo1@bar1@net2#foo1@bar1@inst1_d',
            path=identifier.path,
            terminal_name=identifier.term_name,
        )
        assert adapter.info_from_identifier(identifier) == expected
1

There is 1 answer

2
Samwise — accepted answer

Define PyWaves at module (global) scope and inject the mock function into it. You could inject the whole object instead, but since you only need one method here, there is no reason to complicate it:

class PyWaves:
    """Test double for a wave list that exposes only ``names()``.

    Args:
        mock_wave_names_fn: Zero-argument callable whose return value is
            what ``names()`` hands back.
    """

    def __init__(self, mock_wave_names_fn):
        # Kept uncalled here; names() invokes it lazily on each call.
        self.mock_wave_names = mock_wave_names_fn

    def names(self):
        """Call the injected function and return its result."""
        provider = self.mock_wave_names
        return provider()

and then you can replace this:

    def test_build_index(self, mocker):
        """Original version: re-declares the PyWaves fake inside the test."""
        test_class_self = self  # alias so the nested class can reach the test instance
        class PyWaves(object):
            def names(self):
                # `self` is the PyWaves instance here, hence the outer alias.
                return test_class_self.mock_wave_names()
        waves = PyWaves()
        adapter = ExtractedXPSNameAdapter()
        adapter.set_wave_list(waves)
        adapter.build_index()
        assert adapter.index == self.expected_index()

with a DRYer version:

    def test_build_index(self, mocker):
        """Indexing the injected fake wave names reproduces ``expected_index()``."""
        adapter = ExtractedXPSNameAdapter()
        fake_waves = PyWaves(self.mock_wave_names)
        adapter.set_wave_list(fake_waves)
        adapter.build_index()
        actual = adapter.index
        assert actual == self.expected_index()