Updated black version

This commit is contained in:
romi1502
2022-02-11 11:45:21 +00:00
parent fb7039b1a4
commit 0d64981fb8
12 changed files with 470 additions and 373 deletions

777
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -44,7 +44,7 @@ packages = [ { include = "spleeter" } ]
include = ["LICENSE", "spleeter/resources/*.json"]
[tool.poetry.dependencies]
python = ">=3.6.1,<3.10"
python = ">=3.6.2,<3.10"
ffmpeg-python = "0.2.0"
norbert = "0.2.1"
httpx = {extras = ["http2"], version = "^0.19.0"}
@@ -62,7 +62,7 @@ llvmlite = "^0.36.0"
[tool.poetry.dev-dependencies]
pytest = "^6.2.1"
isort = "^5.7.0"
black = "^20.8b1"
black = "^21.7b"
mypy = "^0.790"
pytest-forked = "^1.3.0"
musdb = "0.3.1"

View File

@@ -19,6 +19,6 @@ __license__ = "MIT License"
class SpleeterError(Exception):
""" Custom exception for Spleeter related error. """
"""Custom exception for Spleeter related error."""
pass

View File

@@ -251,7 +251,7 @@ def evaluate(
def entrypoint():
""" Application entrypoint. """
"""Application entrypoint."""
try:
spleeter()
except SpleeterError as e:

View File

@@ -18,7 +18,7 @@ __license__ = "MIT License"
class Codec(str, Enum):
""" Enumeration of supported audio codec. """
"""Enumeration of supported audio codec."""
WAV: str = "wav"
MP3: str = "mp3"
@@ -29,7 +29,7 @@ class Codec(str, Enum):
class STFTBackend(str, Enum):
""" Enumeration of supported STFT backend. """
"""Enumeration of supported STFT backend."""
AUTO: str = "auto"
TENSORFLOW: str = "tensorflow"

View File

@@ -28,7 +28,7 @@ __license__ = "MIT License"
class AudioAdapter(ABC):
""" An abstract class for manipulating audio signal. """
"""An abstract class for manipulating audio signal."""
_DEFAULT: "AudioAdapter" = None
""" Default audio adapter singleton instance. """

View File

@@ -129,7 +129,7 @@ def get_validation_dataset(
class InstrumentDatasetBuilder(object):
""" Instrument based filter and mapper provider. """
"""Instrument based filter and mapper provider."""
def __init__(self, parent, instrument) -> None:
"""
@@ -148,7 +148,7 @@ class InstrumentDatasetBuilder(object):
self._max_spectrogram_key = f"max_{instrument}_spectrogram"
def load_waveform(self, sample):
""" Load waveform for given sample. """
"""Load waveform for given sample."""
return dict(
sample,
**self._parent._audio_adapter.load_tf_waveform(
@@ -161,7 +161,7 @@ class InstrumentDatasetBuilder(object):
)
def compute_spectrogram(self, sample):
""" Compute spectrogram of the given sample. """
"""Compute spectrogram of the given sample."""
return dict(
sample,
**{
@@ -187,7 +187,7 @@ class InstrumentDatasetBuilder(object):
)
def convert_to_uint(self, sample):
""" Convert given sample from float to unit. """
"""Convert given sample from float to unit."""
return dict(
sample,
**spectrogram_to_db_uint(
@@ -199,11 +199,11 @@ class InstrumentDatasetBuilder(object):
)
def filter_infinity(self, sample):
""" Filter infinity sample. """
"""Filter infinity sample."""
return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))
def convert_to_float32(self, sample):
""" Convert given sample from unit to float. """
"""Convert given sample from unit to float."""
return dict(
sample,
**{
@@ -219,7 +219,7 @@ class InstrumentDatasetBuilder(object):
""" """
def start(sample):
""" mid_segment_start """
"""mid_segment_start"""
return tf.cast(
tf.maximum(
tf.shape(sample[self._spectrogram_key])[0] / 2
@@ -239,14 +239,14 @@ class InstrumentDatasetBuilder(object):
)
def filter_shape(self, sample):
""" Filter badly shaped sample. """
"""Filter badly shaped sample."""
return check_tensor_shape(
sample[self._spectrogram_key],
(self._parent._T, self._parent._F, self._parent._n_channels),
)
def reshape_spectrogram(self, sample):
""" Reshape given sample. """
"""Reshape given sample."""
return dict(
sample,
**{
@@ -326,7 +326,7 @@ class DatasetBuilder(object):
)
def expand_path(self, sample):
""" Expands audio paths for the given sample. """
"""Expands audio paths for the given sample."""
return dict(
sample,
**{
@@ -338,15 +338,15 @@ class DatasetBuilder(object):
)
def filter_error(self, sample):
""" Filter errored sample. """
"""Filter errored sample."""
return tf.logical_not(sample["waveform_error"])
def filter_waveform(self, sample):
""" Filter waveform from sample. """
"""Filter waveform from sample."""
return {k: v for k, v in sample.items() if not k == "waveform"}
def harmonize_spectrogram(self, sample):
""" Ensure same size for vocals and mix spectrograms. """
"""Ensure same size for vocals and mix spectrograms."""
def _reduce(sample):
return tf.reduce_min(
@@ -367,7 +367,7 @@ class DatasetBuilder(object):
)
def filter_short_segments(self, sample):
""" Filter out too short segment. """
"""Filter out too short segment."""
return tf.reduce_any(
[
tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
@@ -376,7 +376,7 @@ class DatasetBuilder(object):
)
def random_time_crop(self, sample):
""" Random time crop of 11.88s. """
"""Random time crop of 11.88s."""
return dict(
sample,
**sync_apply(
@@ -393,7 +393,7 @@ class DatasetBuilder(object):
)
def random_time_stretch(self, sample):
""" Randomly time stretch the given sample. """
"""Randomly time stretch the given sample."""
return dict(
sample,
**sync_apply(
@@ -406,7 +406,7 @@ class DatasetBuilder(object):
)
def random_pitch_shift(self, sample):
""" Randomly pitch shift the given sample. """
"""Randomly pitch shift the given sample."""
return dict(
sample,
**sync_apply(
@@ -420,7 +420,7 @@ class DatasetBuilder(object):
)
def map_features(self, sample):
""" Select features and annotation of the given sample. """
"""Select features and annotation of the given sample."""
input_ = {
f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
}

View File

@@ -94,5 +94,5 @@ def apply_blstm(
def blstm(
input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
""" Model function applier. """
"""Model function applier."""
return apply(apply_blstm, input_tensor, output_name, params)

View File

@@ -193,7 +193,7 @@ def apply_unet(
def unet(
input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
""" Model function applier. """
"""Model function applier."""
return apply(apply_unet, input_tensor, instruments, params)

View File

@@ -51,7 +51,7 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider):
""" A ModelProvider implementation backed on Github for remote storage. """
"""A ModelProvider implementation backed on Github for remote storage."""
DEFAULT_HOST: str = "https://github.com"
DEFAULT_REPOSITORY: str = "deezer/spleeter"

View File

@@ -53,15 +53,15 @@ class DataGenerator(object):
"""
def __init__(self) -> None:
""" Default constructor. """
"""Default constructor."""
self._current_data = None
def update_data(self, data) -> None:
""" Replace internal data. """
"""Replace internal data."""
self._current_data = data
def __call__(self) -> Generator:
""" Generation process. """
"""Generation process."""
buffer = self._current_data
while buffer:
yield buffer
@@ -94,7 +94,7 @@ def create_estimator(params, MWF):
class Separator(object):
""" A wrapper class for performing separation. """
"""A wrapper class for performing separation."""
def __init__(
self,

View File

@@ -21,7 +21,7 @@ environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
class TyperLoggerHandler(logging.Handler):
""" A custom logger handler that use Typer echo. """
"""A custom logger handler that use Typer echo."""
def emit(self, record: logging.LogRecord) -> None:
echo(self.format(record))