Updated black version

This commit is contained in:
romi1502
2022-02-11 11:45:21 +00:00
parent fb7039b1a4
commit 0d64981fb8
12 changed files with 470 additions and 373 deletions

777
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -44,7 +44,7 @@ packages = [ { include = "spleeter" } ]
include = ["LICENSE", "spleeter/resources/*.json"] include = ["LICENSE", "spleeter/resources/*.json"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.6.1,<3.10" python = ">=3.6.2,<3.10"
ffmpeg-python = "0.2.0" ffmpeg-python = "0.2.0"
norbert = "0.2.1" norbert = "0.2.1"
httpx = {extras = ["http2"], version = "^0.19.0"} httpx = {extras = ["http2"], version = "^0.19.0"}
@@ -62,7 +62,7 @@ llvmlite = "^0.36.0"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^6.2.1" pytest = "^6.2.1"
isort = "^5.7.0" isort = "^5.7.0"
black = "^20.8b1" black = "^21.7b"
mypy = "^0.790" mypy = "^0.790"
pytest-forked = "^1.3.0" pytest-forked = "^1.3.0"
musdb = "0.3.1" musdb = "0.3.1"

View File

@@ -19,6 +19,6 @@ __license__ = "MIT License"
class SpleeterError(Exception): class SpleeterError(Exception):
""" Custom exception for Spleeter related error. """ """Custom exception for Spleeter related error."""
pass pass

View File

@@ -251,7 +251,7 @@ def evaluate(
def entrypoint(): def entrypoint():
""" Application entrypoint. """ """Application entrypoint."""
try: try:
spleeter() spleeter()
except SpleeterError as e: except SpleeterError as e:

View File

@@ -18,7 +18,7 @@ __license__ = "MIT License"
class Codec(str, Enum): class Codec(str, Enum):
""" Enumeration of supported audio codec. """ """Enumeration of supported audio codec."""
WAV: str = "wav" WAV: str = "wav"
MP3: str = "mp3" MP3: str = "mp3"
@@ -29,7 +29,7 @@ class Codec(str, Enum):
class STFTBackend(str, Enum): class STFTBackend(str, Enum):
""" Enumeration of supported STFT backend. """ """Enumeration of supported STFT backend."""
AUTO: str = "auto" AUTO: str = "auto"
TENSORFLOW: str = "tensorflow" TENSORFLOW: str = "tensorflow"

View File

@@ -28,7 +28,7 @@ __license__ = "MIT License"
class AudioAdapter(ABC): class AudioAdapter(ABC):
""" An abstract class for manipulating audio signal. """ """An abstract class for manipulating audio signal."""
_DEFAULT: "AudioAdapter" = None _DEFAULT: "AudioAdapter" = None
""" Default audio adapter singleton instance. """ """ Default audio adapter singleton instance. """

View File

@@ -129,7 +129,7 @@ def get_validation_dataset(
class InstrumentDatasetBuilder(object): class InstrumentDatasetBuilder(object):
""" Instrument based filter and mapper provider. """ """Instrument based filter and mapper provider."""
def __init__(self, parent, instrument) -> None: def __init__(self, parent, instrument) -> None:
""" """
@@ -148,7 +148,7 @@ class InstrumentDatasetBuilder(object):
self._max_spectrogram_key = f"max_{instrument}_spectrogram" self._max_spectrogram_key = f"max_{instrument}_spectrogram"
def load_waveform(self, sample): def load_waveform(self, sample):
""" Load waveform for given sample. """ """Load waveform for given sample."""
return dict( return dict(
sample, sample,
**self._parent._audio_adapter.load_tf_waveform( **self._parent._audio_adapter.load_tf_waveform(
@@ -161,7 +161,7 @@ class InstrumentDatasetBuilder(object):
) )
def compute_spectrogram(self, sample): def compute_spectrogram(self, sample):
""" Compute spectrogram of the given sample. """ """Compute spectrogram of the given sample."""
return dict( return dict(
sample, sample,
**{ **{
@@ -187,7 +187,7 @@ class InstrumentDatasetBuilder(object):
) )
def convert_to_uint(self, sample): def convert_to_uint(self, sample):
""" Convert given sample from float to unit. """ """Convert given sample from float to unit."""
return dict( return dict(
sample, sample,
**spectrogram_to_db_uint( **spectrogram_to_db_uint(
@@ -199,11 +199,11 @@ class InstrumentDatasetBuilder(object):
) )
def filter_infinity(self, sample): def filter_infinity(self, sample):
""" Filter infinity sample. """ """Filter infinity sample."""
return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key])) return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))
def convert_to_float32(self, sample): def convert_to_float32(self, sample):
""" Convert given sample from unit to float. """ """Convert given sample from unit to float."""
return dict( return dict(
sample, sample,
**{ **{
@@ -219,7 +219,7 @@ class InstrumentDatasetBuilder(object):
""" """ """ """
def start(sample): def start(sample):
""" mid_segment_start """ """mid_segment_start"""
return tf.cast( return tf.cast(
tf.maximum( tf.maximum(
tf.shape(sample[self._spectrogram_key])[0] / 2 tf.shape(sample[self._spectrogram_key])[0] / 2
@@ -239,14 +239,14 @@ class InstrumentDatasetBuilder(object):
) )
def filter_shape(self, sample): def filter_shape(self, sample):
""" Filter badly shaped sample. """ """Filter badly shaped sample."""
return check_tensor_shape( return check_tensor_shape(
sample[self._spectrogram_key], sample[self._spectrogram_key],
(self._parent._T, self._parent._F, self._parent._n_channels), (self._parent._T, self._parent._F, self._parent._n_channels),
) )
def reshape_spectrogram(self, sample): def reshape_spectrogram(self, sample):
""" Reshape given sample. """ """Reshape given sample."""
return dict( return dict(
sample, sample,
**{ **{
@@ -326,7 +326,7 @@ class DatasetBuilder(object):
) )
def expand_path(self, sample): def expand_path(self, sample):
""" Expands audio paths for the given sample. """ """Expands audio paths for the given sample."""
return dict( return dict(
sample, sample,
**{ **{
@@ -338,15 +338,15 @@ class DatasetBuilder(object):
) )
def filter_error(self, sample): def filter_error(self, sample):
""" Filter errored sample. """ """Filter errored sample."""
return tf.logical_not(sample["waveform_error"]) return tf.logical_not(sample["waveform_error"])
def filter_waveform(self, sample): def filter_waveform(self, sample):
""" Filter waveform from sample. """ """Filter waveform from sample."""
return {k: v for k, v in sample.items() if not k == "waveform"} return {k: v for k, v in sample.items() if not k == "waveform"}
def harmonize_spectrogram(self, sample): def harmonize_spectrogram(self, sample):
""" Ensure same size for vocals and mix spectrograms. """ """Ensure same size for vocals and mix spectrograms."""
def _reduce(sample): def _reduce(sample):
return tf.reduce_min( return tf.reduce_min(
@@ -367,7 +367,7 @@ class DatasetBuilder(object):
) )
def filter_short_segments(self, sample): def filter_short_segments(self, sample):
""" Filter out too short segment. """ """Filter out too short segment."""
return tf.reduce_any( return tf.reduce_any(
[ [
tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
@@ -376,7 +376,7 @@ class DatasetBuilder(object):
) )
def random_time_crop(self, sample): def random_time_crop(self, sample):
""" Random time crop of 11.88s. """ """Random time crop of 11.88s."""
return dict( return dict(
sample, sample,
**sync_apply( **sync_apply(
@@ -393,7 +393,7 @@ class DatasetBuilder(object):
) )
def random_time_stretch(self, sample): def random_time_stretch(self, sample):
""" Randomly time stretch the given sample. """ """Randomly time stretch the given sample."""
return dict( return dict(
sample, sample,
**sync_apply( **sync_apply(
@@ -406,7 +406,7 @@ class DatasetBuilder(object):
) )
def random_pitch_shift(self, sample): def random_pitch_shift(self, sample):
""" Randomly pitch shift the given sample. """ """Randomly pitch shift the given sample."""
return dict( return dict(
sample, sample,
**sync_apply( **sync_apply(
@@ -420,7 +420,7 @@ class DatasetBuilder(object):
) )
def map_features(self, sample): def map_features(self, sample):
""" Select features and annotation of the given sample. """ """Select features and annotation of the given sample."""
input_ = { input_ = {
f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"] f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
} }

View File

@@ -94,5 +94,5 @@ def apply_blstm(
def blstm( def blstm(
input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor: ) -> tf.Tensor:
""" Model function applier. """ """Model function applier."""
return apply(apply_blstm, input_tensor, output_name, params) return apply(apply_blstm, input_tensor, output_name, params)

View File

@@ -193,7 +193,7 @@ def apply_unet(
def unet( def unet(
input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict: ) -> Dict:
""" Model function applier. """ """Model function applier."""
return apply(apply_unet, input_tensor, instruments, params) return apply(apply_unet, input_tensor, instruments, params)

View File

@@ -51,7 +51,7 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider): class GithubModelProvider(ModelProvider):
""" A ModelProvider implementation backed on Github for remote storage. """ """A ModelProvider implementation backed on Github for remote storage."""
DEFAULT_HOST: str = "https://github.com" DEFAULT_HOST: str = "https://github.com"
DEFAULT_REPOSITORY: str = "deezer/spleeter" DEFAULT_REPOSITORY: str = "deezer/spleeter"

View File

@@ -53,15 +53,15 @@ class DataGenerator(object):
""" """
def __init__(self) -> None: def __init__(self) -> None:
""" Default constructor. """ """Default constructor."""
self._current_data = None self._current_data = None
def update_data(self, data) -> None: def update_data(self, data) -> None:
""" Replace internal data. """ """Replace internal data."""
self._current_data = data self._current_data = data
def __call__(self) -> Generator: def __call__(self) -> Generator:
""" Generation process. """ """Generation process."""
buffer = self._current_data buffer = self._current_data
while buffer: while buffer:
yield buffer yield buffer
@@ -94,7 +94,7 @@ def create_estimator(params, MWF):
class Separator(object): class Separator(object):
""" A wrapper class for performing separation. """ """A wrapper class for performing separation."""
def __init__( def __init__(
self, self,

View File

@@ -21,7 +21,7 @@ environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
class TyperLoggerHandler(logging.Handler): class TyperLoggerHandler(logging.Handler):
""" A custom logger handler that use Typer echo. """ """A custom logger handler that use Typer echo."""
def emit(self, record: logging.LogRecord) -> None: def emit(self, record: logging.LogRecord) -> None:
echo(self.format(record)) echo(self.format(record))