diff --git a/dbldatagen/text_generators.py b/dbldatagen/text_generators.py index 965350be..bcad0a07 100644 --- a/dbldatagen/text_generators.py +++ b/dbldatagen/text_generators.py @@ -10,6 +10,7 @@ import random import logging +from abc import ABC, abstractmethod import numpy as np import pandas as pd @@ -61,7 +62,7 @@ 'LABORUM'] -class TextGenerator(object): +class TextGenerator(ABC): """ Base class for text generation classes """ @@ -161,6 +162,10 @@ def getAsTupleOrElse(v, defaultValue, valueName): return defaultValue + @abstractmethod + def pandasGenerateText(self, v): + raise NotImplementedError("Subclasses should implement unique versions of `pandasGenerateText`") + class TemplateGenerator(TextGenerator): # lgtm [py/missing-equals] """This class handles the generation of text from templates diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index fb23d9d3..bec8df11 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -1,9 +1,9 @@ import re -import pytest -import pandas as pd -import numpy as np +import numpy as np +import pandas as pd import pyspark.sql.functions as F +import pytest from pyspark.sql.types import BooleanType, DateType from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType @@ -44,9 +44,13 @@ class TestTextGeneration: row_count = 100000 partitions_requested = 4 + class TestTextGenerator(TextGenerator): + def pandasGenerateText(self, v): # pylint: disable=useless-parent-delegation + return super().pandasGenerateText(v) + def test_text_generator_basics(self): # test the random humber generator - tg1 = TextGenerator() + tg1 = self.TestTextGenerator() # test the repr desc = repr(tg1) diff --git a/tests/test_text_generator_basic.py b/tests/test_text_generator_basic.py index 238212e7..d3a5d8f4 100644 --- a/tests/test_text_generator_basic.py +++ b/tests/test_text_generator_basic.py @@ -1,7 +1,8 @@ import re -import pytest + import numpy as np import pandas as pd +import pytest from dbldatagen import TextGenerator, TemplateGenerator @@ -12,10 +13,14 @@ class TestTextGeneratorBasic: row_count = 100000 partitions_requested = 4 + class TestTextGenerator(TextGenerator): + def pandasGenerateText(self, v): # pylint: disable=useless-parent-delegation + return super().pandasGenerateText(v) + @pytest.mark.parametrize("randomSeed", [None, 0, -1, 2112, 42]) def test_text_generator_basic(self, randomSeed): - text_gen1 = TextGenerator() - text_gen2 = TextGenerator() + text_gen1 = self.TestTextGenerator() + text_gen2 = self.TestTextGenerator() if randomSeed is not None: text_gen1 = text_gen1.withRandomSeed(randomSeed) @@ -29,14 +34,19 @@ def test_text_generator_basic(self, randomSeed): assert text_gen1 == text_gen2 + def test_base_textgenerator_raises_error(self): + with pytest.raises(NotImplementedError): + text_gen1 = self.TestTextGenerator() + text_gen1.pandasGenerateText(None) + @pytest.mark.parametrize("randomSeed, forceNewInstance", [(None, True), (None, False), (0, True), (0, False), (-1, True), (-1, False), (2112, True), (2112, False), (42, True), (42, False)]) def test_text_generator_rng(self, randomSeed, forceNewInstance): - text_gen1 = TextGenerator() - text_gen2 = TextGenerator() + text_gen1 = self.TestTextGenerator() + text_gen2 = self.TestTextGenerator() if randomSeed is not None: text_gen1 = text_gen1.withRandomSeed(randomSeed) @@ -71,7 +81,7 @@ def test_text_generator_rng(self, randomSeed, forceNewInstance): (np.array([1, 40000.4, 3]), np.uint16) ]) def test_text_generator_compact_types(self, values, expectedType): - text_gen1 = TextGenerator() + text_gen1 = self.TestTextGenerator() np_type = text_gen1.compactNumpyTypeForValues(values) assert np_type == expectedType