mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
@@ -16,3 +16,4 @@ dependencies:
|
|||||||
- musdb==0.3.1
|
- musdb==0.3.1
|
||||||
- norbert==0.2.1
|
- norbert==0.2.1
|
||||||
- spleeter
|
- spleeter
|
||||||
|
- ffmpeg-python
|
||||||
|
|||||||
@@ -16,4 +16,5 @@ dependencies:
|
|||||||
- musdb==0.3.1
|
- musdb==0.3.1
|
||||||
- norbert==0.2.1
|
- norbert==0.2.1
|
||||||
- spleeter
|
- spleeter
|
||||||
|
- ffmpeg-python
|
||||||
|
|
||||||
|
|||||||
9
setup.py
9
setup.py
@@ -14,7 +14,7 @@ __license__ = 'MIT License'
|
|||||||
|
|
||||||
# Default project values.
|
# Default project values.
|
||||||
project_name = 'spleeter'
|
project_name = 'spleeter'
|
||||||
project_version = '1.4.1'
|
project_version = '1.4.2'
|
||||||
device_target = 'cpu'
|
device_target = 'cpu'
|
||||||
tensorflow_dependency = 'tensorflow'
|
tensorflow_dependency = 'tensorflow'
|
||||||
tensorflow_version = '1.14.0'
|
tensorflow_version = '1.14.0'
|
||||||
@@ -63,14 +63,17 @@ setup(
|
|||||||
python_requires='>=3.6, <3.8',
|
python_requires='>=3.6, <3.8',
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
'ffmpeg-python',
|
||||||
'importlib_resources ; python_version<"3.7"',
|
'importlib_resources ; python_version<"3.7"',
|
||||||
'musdb==0.3.1',
|
|
||||||
'museval==0.3.0',
|
|
||||||
'norbert==0.2.1',
|
'norbert==0.2.1',
|
||||||
'pandas==0.25.1',
|
'pandas==0.25.1',
|
||||||
'requests',
|
'requests',
|
||||||
|
'setuptools>=41.0.0',
|
||||||
'{}=={}'.format(tensorflow_dependency, tensorflow_version),
|
'{}=={}'.format(tensorflow_dependency, tensorflow_version),
|
||||||
],
|
],
|
||||||
|
extras_require={
|
||||||
|
'evaluation': ['musdb==0.3.1', 'museval==0.3.0']
|
||||||
|
},
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': ['spleeter=spleeter.__main__:entrypoint']
|
'console_scripts': ['spleeter=spleeter.__main__:entrypoint']
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import warnings
|
|||||||
|
|
||||||
from .commands import create_argument_parser
|
from .commands import create_argument_parser
|
||||||
from .utils.configuration import load_configuration
|
from .utils.configuration import load_configuration
|
||||||
from .utils.logging import enable_logging, enable_verbose_logging
|
from .utils.logging import enable_logging, enable_tensorflow_logging
|
||||||
|
|
||||||
__email__ = 'research@deezer.com'
|
__email__ = 'research@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
@@ -28,10 +28,9 @@ def main(argv):
|
|||||||
"""
|
"""
|
||||||
parser = create_argument_parser()
|
parser = create_argument_parser()
|
||||||
arguments = parser.parse_args(argv[1:])
|
arguments = parser.parse_args(argv[1:])
|
||||||
|
enable_logging()
|
||||||
if arguments.verbose:
|
if arguments.verbose:
|
||||||
enable_verbose_logging()
|
enable_tensorflow_logging()
|
||||||
else:
|
|
||||||
enable_logging()
|
|
||||||
if arguments.command == 'separate':
|
if arguments.command == 'separate':
|
||||||
from .commands.separate import entrypoint
|
from .commands.separate import entrypoint
|
||||||
elif arguments.command == 'train':
|
elif arguments.command == 'train':
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ def create_argument_parser():
|
|||||||
|
|
||||||
:returns: Created argument parser.
|
:returns: Created argument parser.
|
||||||
"""
|
"""
|
||||||
parser = ArgumentParser(prog='python -m spleeter')
|
parser = ArgumentParser(prog='spleeter')
|
||||||
subparsers = parser.add_subparsers()
|
subparsers = parser.add_subparsers()
|
||||||
subparsers.dest = 'command'
|
subparsers.dest = 'command'
|
||||||
subparsers.required = True
|
subparsers.required = True
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
--mus_dir /path/to/musdb dataset
|
--mus_dir /path/to/musdb dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
@@ -21,8 +22,6 @@ from glob import glob
|
|||||||
from os.path import join, exists
|
from os.path import join, exists
|
||||||
|
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
import musdb
|
|
||||||
import museval
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
# pylint: enable=import-error
|
# pylint: enable=import-error
|
||||||
@@ -30,6 +29,15 @@ import pandas as pd
|
|||||||
from .separate import entrypoint as separate_entrypoint
|
from .separate import entrypoint as separate_entrypoint
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
import musdb
|
||||||
|
import museval
|
||||||
|
except ImportError:
|
||||||
|
logger = get_logger()
|
||||||
|
logger.error('Extra dependencies musdb and museval not found')
|
||||||
|
logger.error('Please install musdb and museval first, abort')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
__email__ = 'research@deezer.com'
|
__email__ = 'research@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|||||||
@@ -129,7 +129,6 @@ def process_audio(
|
|||||||
yield_single_examples=False)
|
yield_single_examples=False)
|
||||||
# initialize pool for audio export
|
# initialize pool for audio export
|
||||||
pool = Pool(16)
|
pool = Pool(16)
|
||||||
tasks = []
|
|
||||||
for sample in prediction:
|
for sample in prediction:
|
||||||
sample_filename = sample.pop('audio_id', 'unknown_filename').decode()
|
sample_filename = sample.pop('audio_id', 'unknown_filename').decode()
|
||||||
input_directory, input_filename = split(sample_filename)
|
input_directory, input_filename = split(sample_filename)
|
||||||
@@ -144,13 +143,12 @@ def process_audio(
|
|||||||
output_path,
|
output_path,
|
||||||
output_dirname,
|
output_dirname,
|
||||||
f'{instrument}.{codec}')
|
f'{instrument}.{codec}')
|
||||||
tasks.append(
|
pool.apply_async(
|
||||||
pool.apply_async(
|
audio_adapter.save,
|
||||||
audio_adapter.save,
|
(filename, waveform, sample_rate, codec))
|
||||||
(filename, waveform, sample_rate, codec)))
|
|
||||||
# Wait for everything to be written
|
# Wait for everything to be written
|
||||||
for task in tasks:
|
pool.close()
|
||||||
task.wait(timeout=20)
|
pool.join()
|
||||||
|
|
||||||
|
|
||||||
def entrypoint(arguments, params):
|
def entrypoint(arguments, params):
|
||||||
|
|||||||
@@ -3,10 +3,9 @@
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
This module contains building functions for U-net source separation source
|
This module contains building functions for U-net source separation source
|
||||||
separation models.
|
separation models. Each instrument is modeled by a single U-netconvolutional
|
||||||
Each instrument is modeled by a single U-net convolutional/deconvolutional
|
/ deconvolutional network that take a mix spectrogram as input and the
|
||||||
network that take a mix spectrogram as input and the estimated sound spectrogram
|
estimated sound spectrogram as output.
|
||||||
as output.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|||||||
@@ -65,9 +65,9 @@ class GithubModelProvider(ModelProvider):
|
|||||||
raise IOError(f'Resource {url} not found')
|
raise IOError(f'Resource {url} not found')
|
||||||
with TemporaryFile() as stream:
|
with TemporaryFile() as stream:
|
||||||
copyfileobj(response.raw, stream)
|
copyfileobj(response.raw, stream)
|
||||||
get_logger().debug('Extracting downloaded archive')
|
get_logger().info('Extracting downloaded %s archive', name)
|
||||||
stream.seek(0)
|
stream.seek(0)
|
||||||
tar = tarfile.open(fileobj=stream)
|
tar = tarfile.open(fileobj=stream)
|
||||||
tar.extractall(path=path)
|
tar.extractall(path=path)
|
||||||
tar.close()
|
tar.close()
|
||||||
get_logger().debug('Model file extracted')
|
get_logger().info('%s model file(s) extracted', name)
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ class AudioAdapter(ABC):
|
|||||||
duration.numpy(),
|
duration.numpy(),
|
||||||
sample_rate.numpy(),
|
sample_rate.numpy(),
|
||||||
dtype=dtype.numpy())
|
dtype=dtype.numpy())
|
||||||
|
get_logger().info('Audio data loaded successfully')
|
||||||
return (data, False)
|
return (data, False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().warning(e)
|
get_logger().warning(e)
|
||||||
|
|||||||
@@ -9,12 +9,11 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import os.path
|
|
||||||
import platform
|
|
||||||
import re
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
import numpy as np # pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
|
import ffmpeg
|
||||||
|
import numpy as np
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
from .adapter import AudioAdapter
|
from .adapter import AudioAdapter
|
||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
@@ -23,58 +22,9 @@ __email__ = 'research@deezer.com'
|
|||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
# Default FFMPEG binary name.
|
|
||||||
_UNIX_BINARY = 'ffmpeg'
|
|
||||||
_WINDOWS_BINARY = 'ffmpeg.exe'
|
|
||||||
|
|
||||||
|
|
||||||
def _which(program):
|
|
||||||
""" A pure python implementation of `which`command
|
|
||||||
for retrieving absolute path from command name or path.
|
|
||||||
|
|
||||||
@see https://stackoverflow.com/a/377028/1211342
|
|
||||||
|
|
||||||
:param program: Program name or path to expend.
|
|
||||||
:returns: Absolute path of program if any, None otherwise.
|
|
||||||
"""
|
|
||||||
def is_exe(fpath):
|
|
||||||
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
|
|
||||||
|
|
||||||
fpath, _ = os.path.split(program)
|
|
||||||
if fpath:
|
|
||||||
if is_exe(program):
|
|
||||||
return program
|
|
||||||
else:
|
|
||||||
for path in os.environ['PATH'].split(os.pathsep):
|
|
||||||
exe_file = os.path.join(path, program)
|
|
||||||
if is_exe(exe_file):
|
|
||||||
return exe_file
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_ffmpeg_path():
|
|
||||||
""" Retrieves FFMPEG binary path using ENVVAR if defined
|
|
||||||
or default binary name (Windows or UNIX style).
|
|
||||||
|
|
||||||
:returns: Absolute path of FFMPEG binary.
|
|
||||||
:raise IOError: If FFMPEG binary cannot be found.
|
|
||||||
"""
|
|
||||||
ffmpeg_path = os.environ.get('FFMPEG_PATH', None)
|
|
||||||
if ffmpeg_path is None:
|
|
||||||
# Note: try to infer standard binary name regarding of platform.
|
|
||||||
if platform.system() == 'Windows':
|
|
||||||
ffmpeg_path = _WINDOWS_BINARY
|
|
||||||
else:
|
|
||||||
ffmpeg_path = _UNIX_BINARY
|
|
||||||
expended = _which(ffmpeg_path)
|
|
||||||
if expended is None:
|
|
||||||
raise IOError(f'FFMPEG binary ({ffmpeg_path}) not found')
|
|
||||||
return expended
|
|
||||||
|
|
||||||
|
|
||||||
def _to_ffmpeg_time(n):
|
def _to_ffmpeg_time(n):
|
||||||
""" Format number of seconds to time expected by FFMPEG.
|
""" Format number of seconds to time expected by FFMPEG.
|
||||||
|
|
||||||
:param n: Time in seconds to format.
|
:param n: Time in seconds to format.
|
||||||
:returns: Formatted time in FFMPEG format.
|
:returns: Formatted time in FFMPEG format.
|
||||||
"""
|
"""
|
||||||
@@ -83,56 +33,6 @@ def _to_ffmpeg_time(n):
|
|||||||
return '%d:%02d:%09.6f' % (h, m, s)
|
return '%d:%02d:%09.6f' % (h, m, s)
|
||||||
|
|
||||||
|
|
||||||
def _parse_ffmpg_results(stderr):
|
|
||||||
""" Extract number of channels and sample rate from
|
|
||||||
the given FFMPEG STDERR output line.
|
|
||||||
|
|
||||||
:param stderr: STDERR output line to parse.
|
|
||||||
:returns: Parsed n_channels and sample_rate values.
|
|
||||||
"""
|
|
||||||
# Setup default value.
|
|
||||||
n_channels = 0
|
|
||||||
sample_rate = 0
|
|
||||||
# Find samplerate
|
|
||||||
match = re.search(r'(\d+) hz', stderr)
|
|
||||||
if match:
|
|
||||||
sample_rate = int(match.group(1))
|
|
||||||
# Channel count.
|
|
||||||
match = re.search(r'hz, ([^,]+),', stderr)
|
|
||||||
if match:
|
|
||||||
mode = match.group(1)
|
|
||||||
if mode == 'stereo':
|
|
||||||
n_channels = 2
|
|
||||||
else:
|
|
||||||
match = re.match(r'(\d+) ', mode)
|
|
||||||
n_channels = match and int(match.group(1)) or 1
|
|
||||||
return n_channels, sample_rate
|
|
||||||
|
|
||||||
|
|
||||||
class _CommandBuilder(object):
|
|
||||||
""" A simple builder pattern class for CLI string. """
|
|
||||||
|
|
||||||
def __init__(self, binary):
|
|
||||||
""" Default constructor. """
|
|
||||||
self._command = [binary]
|
|
||||||
|
|
||||||
def flag(self, flag):
|
|
||||||
""" Add flag or unlabelled opt. """
|
|
||||||
self._command.append(flag)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def opt(self, short, value, formatter=str):
|
|
||||||
""" Add option if value not None. """
|
|
||||||
if value is not None:
|
|
||||||
self._command.append(short)
|
|
||||||
self._command.append(formatter(value))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def command(self):
|
|
||||||
""" Build string command. """
|
|
||||||
return self._command
|
|
||||||
|
|
||||||
|
|
||||||
class FFMPEGProcessAudioAdapter(AudioAdapter):
|
class FFMPEGProcessAudioAdapter(AudioAdapter):
|
||||||
""" An AudioAdapter implementation that use FFMPEG binary through
|
""" An AudioAdapter implementation that use FFMPEG binary through
|
||||||
subprocess in order to perform I/O operation for audio processing.
|
subprocess in order to perform I/O operation for audio processing.
|
||||||
@@ -142,17 +42,6 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
|
|||||||
FFMPEG_PATH environment variable.
|
FFMPEG_PATH environment variable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
""" Default constructor. """
|
|
||||||
self._ffmpeg_path = _get_ffmpeg_path()
|
|
||||||
|
|
||||||
def _get_command_builder(self):
|
|
||||||
""" Creates and returns a command builder using FFMPEG path.
|
|
||||||
|
|
||||||
:returns: Built command builder.
|
|
||||||
"""
|
|
||||||
return _CommandBuilder(self._ffmpeg_path)
|
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
self, path, offset=None, duration=None,
|
self, path, offset=None, duration=None,
|
||||||
sample_rate=None, dtype=np.float32):
|
sample_rate=None, dtype=np.float32):
|
||||||
@@ -168,44 +57,30 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
|
|||||||
"""
|
"""
|
||||||
if not isinstance(path, str):
|
if not isinstance(path, str):
|
||||||
path = path.decode()
|
path = path.decode()
|
||||||
command = (
|
probe = ffmpeg.probe(path)
|
||||||
self._get_command_builder()
|
if 'streams' not in probe or len(probe['streams']) == 0:
|
||||||
.opt('-ss', offset, formatter=_to_ffmpeg_time)
|
raise IOError('No stream was found with ffprobe')
|
||||||
.opt('-t', duration, formatter=_to_ffmpeg_time)
|
metadata = next(
|
||||||
.opt('-i', path)
|
stream
|
||||||
.opt('-ar', sample_rate)
|
for stream in probe['streams']
|
||||||
.opt('-f', 'f32le')
|
if stream['codec_type'] == 'audio')
|
||||||
.flag('-')
|
n_channels = metadata['channels']
|
||||||
.command())
|
if sample_rate is None:
|
||||||
process = subprocess.Popen(
|
sample_rate = metadata['sample_rate']
|
||||||
command,
|
output_kwargs = {'format': 'f32le', 'ar': sample_rate}
|
||||||
stdout=subprocess.PIPE,
|
if duration is not None:
|
||||||
stderr=subprocess.PIPE)
|
output_kwargs['t'] = _to_ffmpeg_time(duration)
|
||||||
buffer = process.stdout.read(-1)
|
if offset is not None:
|
||||||
# Read STDERR until end of the process detected.
|
output_kwargs['ss'] = _to_ffmpeg_time(offset)
|
||||||
while True:
|
process = (
|
||||||
status = process.stderr.readline()
|
ffmpeg
|
||||||
if not status:
|
.input(path)
|
||||||
raise OSError('Stream info not found')
|
.output('pipe:', **output_kwargs)
|
||||||
if isinstance(status, bytes): # Note: Python 3 compatibility.
|
.run_async(pipe_stdout=True, pipe_stderr=True))
|
||||||
status = status.decode('utf8', 'ignore')
|
buffer, _ = process.communicate()
|
||||||
status = status.strip().lower()
|
|
||||||
if 'no such file' in status:
|
|
||||||
raise IOError(f'File {path} not found')
|
|
||||||
elif 'invalid data found' in status:
|
|
||||||
raise IOError(f'FFMPEG error : {status}')
|
|
||||||
elif 'audio:' in status:
|
|
||||||
n_channels, ffmpeg_sample_rate = _parse_ffmpg_results(status)
|
|
||||||
if sample_rate is None:
|
|
||||||
sample_rate = ffmpeg_sample_rate
|
|
||||||
break
|
|
||||||
# Load waveform and clean process.
|
|
||||||
waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
|
waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
|
||||||
if not waveform.dtype == np.dtype(dtype):
|
if not waveform.dtype == np.dtype(dtype):
|
||||||
waveform = waveform.astype(dtype)
|
waveform = waveform.astype(dtype)
|
||||||
process.stdout.close()
|
|
||||||
process.stderr.close()
|
|
||||||
del process
|
|
||||||
return (waveform, sample_rate)
|
return (waveform, sample_rate)
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
@@ -225,39 +100,22 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
|
|||||||
if not os.path.exists(directory):
|
if not os.path.exists(directory):
|
||||||
os.makedirs(directory)
|
os.makedirs(directory)
|
||||||
get_logger().debug('Writing file %s', path)
|
get_logger().debug('Writing file %s', path)
|
||||||
# NOTE: Tweak.
|
input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
|
||||||
if codec == 'wav':
|
output_kwargs = {'ar': sample_rate, 'strict': '-2'}
|
||||||
codec = None
|
if bitrate:
|
||||||
command = (
|
output_kwargs['audio_bitrate'] = bitrate
|
||||||
self._get_command_builder()
|
if codec is not None and codec != 'wav':
|
||||||
.flag('-y')
|
output_kwargs['codec'] = codec
|
||||||
.opt('-loglevel', 'error')
|
process = (
|
||||||
.opt('-f', 'f32le')
|
ffmpeg
|
||||||
.opt('-ar', sample_rate)
|
.input('pipe:', format='f32le', **input_kwargs)
|
||||||
.opt('-ac', data.shape[1])
|
.output(path, **output_kwargs)
|
||||||
.opt('-i', '-')
|
.overwrite_output()
|
||||||
.flag('-vn')
|
.run_async(pipe_stdin=True, quiet=True))
|
||||||
.opt('-acodec', codec)
|
|
||||||
.opt('-ar', sample_rate) # Note: why twice ?
|
|
||||||
.opt('-strict', '-2') # Note: For 'aac' codec support.
|
|
||||||
.opt('-ab', bitrate)
|
|
||||||
.flag(path)
|
|
||||||
.command())
|
|
||||||
process = subprocess.Popen(
|
|
||||||
command,
|
|
||||||
stdout=open(os.devnull, 'wb'),
|
|
||||||
stdin=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE)
|
|
||||||
# Write data to STDIN.
|
|
||||||
try:
|
try:
|
||||||
process.stdin.write(
|
process.stdin.write(data.astype('<f4').tobytes())
|
||||||
data.astype('<f4').tostring())
|
process.stdin.close()
|
||||||
|
process.wait()
|
||||||
except IOError:
|
except IOError:
|
||||||
raise IOError(f'FFMPEG error: {process.stderr.read()}')
|
raise IOError(f'FFMPEG error: {process.stderr.read()}')
|
||||||
# Clean process.
|
|
||||||
process.stdin.close()
|
|
||||||
if process.stderr is not None:
|
|
||||||
process.stderr.close()
|
|
||||||
process.wait()
|
|
||||||
del process
|
|
||||||
get_logger().info('File %s written', path)
|
get_logger().info('File %s written', path)
|
||||||
|
|||||||
@@ -4,6 +4,8 @@
|
|||||||
""" Utility functions for creating estimator. """
|
""" Utility functions for creating estimator. """
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from os.path import join
|
||||||
|
from tempfile import gettempdir
|
||||||
|
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -15,7 +17,7 @@ from ..model import model_fn
|
|||||||
from ..model.provider import get_default_model_provider
|
from ..model.provider import get_default_model_provider
|
||||||
|
|
||||||
# Default exporting directory for predictor.
|
# Default exporting directory for predictor.
|
||||||
DEFAULT_EXPORT_DIRECTORY = '/tmp/serving'
|
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
|
||||||
|
|
||||||
|
|
||||||
def create_estimator(params, MWF):
|
def create_estimator(params, MWF):
|
||||||
|
|||||||
@@ -3,12 +3,16 @@
|
|||||||
|
|
||||||
""" Centralized logging facilities for Spleeter. """
|
""" Centralized logging facilities for Spleeter. """
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from os import environ
|
from os import environ
|
||||||
|
|
||||||
__email__ = 'research@deezer.com'
|
__email__ = 'research@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
_FORMAT = '%(levelname)s:%(name)s:%(message)s'
|
||||||
|
|
||||||
|
|
||||||
class _LoggerHolder(object):
|
class _LoggerHolder(object):
|
||||||
""" Logger singleton instance holder. """
|
""" Logger singleton instance holder. """
|
||||||
@@ -16,30 +20,42 @@ class _LoggerHolder(object):
|
|||||||
INSTANCE = None
|
INSTANCE = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tensorflow_logger():
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# pylint: disable=import-error
|
||||||
|
from tensorflow.compat.v1 import logging
|
||||||
|
# pylint: enable=import-error
|
||||||
|
return logging
|
||||||
|
|
||||||
|
|
||||||
def get_logger():
|
def get_logger():
|
||||||
""" Returns library scoped logger.
|
""" Returns library scoped logger.
|
||||||
|
|
||||||
:returns: Library logger.
|
:returns: Library logger.
|
||||||
"""
|
"""
|
||||||
if _LoggerHolder.INSTANCE is None:
|
if _LoggerHolder.INSTANCE is None:
|
||||||
# pylint: disable=import-error
|
formatter = logging.Formatter(_FORMAT)
|
||||||
from tensorflow.compat.v1 import logging
|
handler = logging.StreamHandler()
|
||||||
# pylint: enable=import-error
|
handler.setFormatter(formatter)
|
||||||
_LoggerHolder.INSTANCE = logging
|
logger = logging.getLogger('spleeter')
|
||||||
_LoggerHolder.INSTANCE.set_verbosity(_LoggerHolder.INSTANCE.ERROR)
|
logger.addHandler(handler)
|
||||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
logger.setLevel(logging.INFO)
|
||||||
|
_LoggerHolder.INSTANCE = logger
|
||||||
return _LoggerHolder.INSTANCE
|
return _LoggerHolder.INSTANCE
|
||||||
|
|
||||||
|
|
||||||
def enable_logging():
|
def enable_tensorflow_logging():
|
||||||
""" Enable INFO level logging. """
|
""" Enable tensorflow logging. """
|
||||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||||
|
tf_logger = get_tensorflow_logger()
|
||||||
|
tf_logger.set_verbosity(tf_logger.INFO)
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
logger.set_verbosity(logger.INFO)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
def enable_verbose_logging():
|
def enable_logging():
|
||||||
""" Enable DEBUG level logging. """
|
""" Configure default logging. """
|
||||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
|
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
logger = get_logger()
|
tf_logger = get_tensorflow_logger()
|
||||||
logger.set_verbosity(logger.DEBUG)
|
tf_logger.set_verbosity(tf_logger.ERROR)
|
||||||
|
|||||||
Reference in New Issue
Block a user