diff --git a/batbot/__init__.py b/batbot/__init__.py
index c415831..947e939 100644
--- a/batbot/__init__.py
+++ b/batbot/__init__.py
@@ -17,10 +17,13 @@ output_paths, metadata_path, metadata = spectrogram.compute(filepath)
 """
-from os.path import exists, join
+import concurrent.futures
+from multiprocessing import Manager
+from os.path import basename, exists, join, splitext
 from pathlib import Path
 
 import pooch
+from tqdm import tqdm
 
 from batbot import utils
@@ -60,10 +63,12 @@ def fetch(pull=False, config=None):
 def pipeline(
     filepath,
-    config=None,
-    # classifier_thresh=classifier.CONFIGS[None]['thresh'],
-    clean=True,
-    output_folder='.',
+    out_file_stem=None,
+    output_folder=None,
+    fast_mode=False,
+    force_overwrite=False,
+    quiet=False,
+    debug=False,
 ):
     """
-    Run the ML pipeline on a given WAV filepath and return the classification results
+    Run the ML pipeline on a given WAV filepath and return the generated output paths
@@ -93,12 +98,138 @@ def pipeline(
     Returns:
-        tuple ( float, list ( dict ) ): classifier score, list of time windows
+        tuple ( list, list, str ): output paths, compressed paths, metadata filepath
     """
+
     # Generate spectrogram
-    output_paths, metadata_path, metadata = spectrogram.compute(
-        filepath, output_folder=output_folder
+    output_paths, compressed_paths, metadata_path, metadata = spectrogram.compute(
+        filepath,
+        out_file_stem=out_file_stem,
+        output_folder=output_folder,
+        fast_mode=fast_mode,
+        force_overwrite=force_overwrite,
+        quiet=quiet,
+        debug=debug,
     )
 
-    return output_paths, metadata_path
+    return output_paths, compressed_paths, metadata_path
+
+
+def pipeline_multi_wrapper(
+    filepaths,
+    out_file_stems=None,
+    fast_mode=False,
+    force_overwrite=False,
+    worker_position=None,
+    quiet=False,
+    tqdm_lock=None,
+):
+    """Fault-tolerant wrapper for multiple inputs.
+
+    Args:
+        filepaths (list): input WAV filepaths, processed sequentially.
+        out_file_stems (list, optional): output file stems matching filepaths. Defaults to None.
+        fast_mode (bool, optional): skip call segmentation and per-call metadata. Defaults to False.
+        force_overwrite (bool, optional): overwrite existing output files. Defaults to False.
+        worker_position (int, optional): tqdm progress bar slot for this worker. Defaults to None.
+        quiet (bool, optional): suppress per-file progress output. Defaults to False.
+        tqdm_lock (optional): shared lock for synchronized tqdm writes. Defaults to None.
+
+    Returns:
+        tuple ( list, list, list, list ): output paths, compressed paths, metadata
+        paths, and (filepath, exception) pairs for files that failed.
+    """
+
+    if out_file_stems is not None:
+        assert len(filepaths) == len(
+            out_file_stems
+        ), 'Input filepaths and out_file_stems have different length.'
+    else:
+        out_file_stems = [None] * len(filepaths)
+
+    outputs = {'output_paths': [], 'compressed_paths': [], 'metadata_paths': [], 'failed_files': []}
+    if tqdm_lock is not None:
+        tqdm.set_lock(tqdm_lock)
+    for in_file, out_stem in tqdm(
+        zip(filepaths, out_file_stems),
+        desc='Processing, worker {}'.format(worker_position),
+        position=worker_position,
+        total=len(filepaths),
+        leave=True,
+    ):
+        try:
+            output_paths, compressed_paths, metadata_path = pipeline(
+                in_file,
+                out_file_stem=out_stem,
+                fast_mode=fast_mode,
+                force_overwrite=force_overwrite,
+                quiet=quiet,
+            )
+            outputs['output_paths'].extend(output_paths)
+            outputs['compressed_paths'].extend(compressed_paths)
+            outputs['metadata_paths'].append(metadata_path)
+        except Exception as e:
+            outputs['failed_files'].append((str(in_file), e))
+
+    return tuple(outputs.values())
+
+
+def parallel_pipeline(
+    in_file_chunks,
+    out_stem_chunks=None,
+    fast_mode=False,
+    force_overwrite=False,
+    num_workers=0,
+    threaded=False,
+    quiet=False,
+    desc=None,
+):
+    """Run pipeline_multi_wrapper over chunks of files using parallel workers."""
+
+    if out_stem_chunks is None:
+        out_stem_chunks = [None] * len(in_file_chunks)
+
+    if len(in_file_chunks) == 0:
+        return None
+    else:
+        assert len(in_file_chunks) == len(
+            out_stem_chunks
+        ), 'in_file_chunks and out_stem_chunks must have the same length.'
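+
+    # A typical invocation pre-chunks a flat list of WAV paths before calling
+    # this function (hypothetical caller code; the CLI preprocess command does
+    # the equivalent with numpy):
+    #
+    #     chunks = np.array_split(np.array(wav_files), num_workers)
+    #     parallel_pipeline(in_file_chunks=chunks, num_workers=num_workers)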
+ + if threaded: + executor_cls = concurrent.futures.ThreadPoolExecutor + else: + executor_cls = concurrent.futures.ProcessPoolExecutor + + num_workers = min(len(in_file_chunks), num_workers) + + outputs = {'output_paths': [], 'compressed_paths': [], 'metadata_paths': [], 'failed_files': []} + + lock_manager = Manager() + tqdm_lock = lock_manager.Lock() + + with tqdm(total=len(in_file_chunks), disable=quiet, desc=desc) as progress: + with executor_cls(max_workers=num_workers) as executor: + + futures = [ + executor.submit( + pipeline_multi_wrapper, + filepaths=file_chunk, + out_file_stems=out_stem_chunk, + fast_mode=fast_mode, + force_overwrite=force_overwrite, + worker_position=index % num_workers, + quiet=quiet, + tqdm_lock=tqdm_lock, + ) + for index, (file_chunk, out_stem_chunk) in enumerate( + zip(in_file_chunks, out_stem_chunks) + ) + ] + + for future in concurrent.futures.as_completed(futures): + output_paths, compressed_paths, metadata_path, failed_files = future.result() + outputs['output_paths'].extend(output_paths) + outputs['compressed_paths'].extend(compressed_paths) + outputs['metadata_paths'].extend(metadata_path) + outputs['failed_files'].extend(failed_files) + progress.update(1) + + return tuple(outputs.values()) def batch( @@ -140,7 +271,7 @@ def batch( # Run tiling batch = {} for filepath in filepaths: - _, _, metadata = spectrogram.compute(filepath) + _, _, _, metadata = spectrogram.compute(filepath) batch[filepath] = metadata raise NotImplementedError @@ -164,7 +295,15 @@ def example(): assert exists(wav_filepath) log.debug(f'Running pipeline on WAV: {wav_filepath}') - output = './output' - results = pipeline(wav_filepath, output_folder=output) + + import time + + output_stem = join('output', splitext(basename(wav_filepath))[0]) + start_time = time.time() + results = pipeline( + wav_filepath, out_file_stem=output_stem, fast_mode=False, force_overwrite=True + ) + stop_time = time.time() + print('Example pipeline completed in {} seconds.'.format(stop_time - start_time)) log.debug(results) diff --git a/batbot/batbot.py b/batbot/batbot.py deleted file mode 100755 index fa137e0..0000000 --- a/batbot/batbot.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python -""" -CLI for BatBot -""" -import json -from os.path import exists - -import click - -import batbot -from batbot import log - - -def pipeline_filepath_validator(ctx, param, value): - if not exists(value): - log.error(f'Input filepath does not exist: {value}') - ctx.exit() - return value - - -@click.command('fetch') -@click.option( - '--config', - help='Which ML model to use for inference', - default=None, - type=click.Choice(['usgs']), -) -def fetch(config): - """ - Fetch the required machine learning ONNX model for the classifier - """ - batbot.fetch(config=config) - - -@click.command('pipeline') -@click.argument( - 'filepath', - nargs=1, - type=str, - callback=pipeline_filepath_validator, -) -@click.option( - '--config', - help='Which ML model to use for inference', - default=None, - type=click.Choice(['usgs']), -) -@click.option( - '--output', - 'output_path', - help='Path to output folder for the results', - default='.', - type=str, -) -# @click.option( -# '--classifier_thresh', -# help='Classifier confidence threshold', -# default=int(classifier.CONFIGS[None]['thresh'] * 100), -# type=click.IntRange(0, 100, clamp=True), -# ) -def pipeline( - filepath, - config, - output_path, - # classifier_thresh, -): - """ - Run the BatBot pipeline on an input WAV filepath. An example output of the JSON - can be seen below. 
- - .. code-block:: javascript - - { - '/path/to/file.wav': { - 'classifier': 0.5, - } - } - """ - if config is not None: - config = config.strip().lower() - # classifier_thresh /= 100.0 - - batbot.pipeline( - filepath, - config=config, - # classifier_thresh=classifier_thresh, - output_folder=output_path, - ) - - -@click.command('batch') -@click.argument( - 'filepaths', - nargs=-1, - type=str, -) -@click.option( - '--config', - help='Which ML model to use for inference', - default=None, - type=click.Choice(['usgs']), -) -@click.option( - '--output', - help='Path to output JSON (if unspecified, results are printed to screen)', - default=None, - type=str, -) -# @click.option( -# '--classifier_thresh', -# help='Classifier confidence threshold', -# default=int(classifier.CONFIGS[None]['thresh'] * 100), -# type=click.IntRange(0, 100, clamp=True), -# ) -def batch( - filepaths, - config, - output, - # classifier_thresh, -): - """ - Run the BatBot pipeline in batch on a list of input WAV filepaths. - An example output of the JSON can be seen below. - - .. code-block:: javascript - - { - '/path/to/file1.wav': { - 'classifier': 0.5, - }, - '/path/to/file2.wav': { - 'classifier': 0.8, - }, - ... - } - """ - if config is not None: - config = config.strip().lower() - # classifier_thresh /= 100.0 - - log.debug(f'Running batch on {len(filepaths)} files...') - - score_list = batbot.batch( - filepaths, - config=config, - # classifier_thresh=classifier_thresh, - ) - - data = {} - for filepath, score in zip(filepaths, score_list): - data[filepath] = { - 'classifier': score, - } - - log.debug('Outputting results...') - if output: - with open(output, 'w') as outfile: - json.dump(data, outfile) - else: - print(data) - - -@click.command('example') -def example(): - """ - Run a test of the pipeline on an example WAV with the default configuration. 
-    """
-    batbot.example()
-
-
-@click.group()
-def cli():
-    """
-    BatBot CLI
-    """
-    pass
-
-
-cli.add_command(fetch)
-cli.add_command(pipeline)
-cli.add_command(batch)
-cli.add_command(example)
-
-
-if __name__ == '__main__':
-    cli()
diff --git a/batbot/batbot_cli.py b/batbot/batbot_cli.py
new file mode 100755
index 0000000..eb24588
--- /dev/null
+++ b/batbot/batbot_cli.py
@@ -0,0 +1,478 @@
+#!/usr/bin/env python
+"""
+CLI for BatBot
+"""
+import json
+import pprint
+import warnings
+from glob import glob
+from os import getcwd, makedirs, remove
+from os.path import (
+    basename,
+    commonpath,
+    exists,
+    isdir,
+    isfile,
+    join,
+    relpath,
+    split,
+    splitext,
+)
+
+import click
+import numpy as np
+from tqdm import tqdm
+
+import batbot
+from batbot import log
+
+
+def pipeline_filepath_validator(ctx, param, value):
+    if not exists(value):
+        log.error(f'Input filepath does not exist: {value}')
+        ctx.exit()
+    return value
+
+
+@click.command('fetch')
+@click.option(
+    '--config',
+    help='Which ML model to use for inference',
+    default=None,
+    type=click.Choice(['usgs']),
+)
+def fetch(config):
+    """
+    Fetch the required machine learning ONNX model for the classifier
+    """
+    batbot.fetch(config=config)
+
+
+@click.command('pipeline')
+@click.argument(
+    'filepath',
+    nargs=1,
+    type=str,
+    callback=pipeline_filepath_validator,
+)
+# @click.option(
+#     '--config',
+#     help='Which ML model to use for inference',
+#     default=None,
+#     type=click.Choice(['usgs']),
+# )
+@click.option(
+    '--output',
+    'output_path',
+    help='Path to output folder for the results',
+    default='.',
+    type=str,
+)
+# @click.option(
+#     '--classifier_thresh',
+#     help='Classifier confidence threshold',
+#     default=int(classifier.CONFIGS[None]['thresh'] * 100),
+#     type=click.IntRange(0, 100, clamp=True),
+# )
+def pipeline(
+    filepath,
+    # config,
+    output_path,
+    # classifier_thresh,
+):
+    """
+    Run the BatBot pipeline on an input WAV filepath, writing results to the output folder.
+    """
+    batbot.pipeline(
+        filepath,
+        # config=config,
+        # classifier_thresh=classifier_thresh,
+        output_folder=output_path,
+    )
+
+
+@click.command('preprocess')
+@click.argument(
+    'filepaths',
+    nargs=-1,
+    type=str,
+)
+@click.option(
+    '--output-dir',
+    '-o',
+    help='Processed file root output directory. Outputs will attempt to mirror input file directory structure if given multiple inputs (unless --no-file-structure flag is given). Defaults to current working directory.',
+    nargs=1,
+    default='.',
+    type=str,
+)
+@click.option(
+    '--process-metadata',
+    '-m',
+    help='Use a slower version of the pipeline which increases spectrogram compression quality and also outputs bat call metadata.',
+    is_flag=True,
+)
+@click.option(
+    '--force-overwrite',
+    '-f',
+    help='Force overwriting of compressed spectrogram and other output files.',
+    is_flag=True,
+)
+@click.option(
+    '--num-workers',
+    '-n',
+    help='Number of parallel workers to use. Set to zero for serial computation only.',
+    nargs=1,
+    default=0,
+    type=int,
+)
+@click.option(
+    '--output-json',
+    help='Path to output JSON (if unspecified, output file locations are printed to screen).',
+    default=None,
+    type=str,
+)
+@click.option(
+    '--dry-run',
+    '-d',
+    help='List out all the audio files to be loaded and all the anticipated output files. Additionally lists all "extra" files in the output directory that would be deleted if using the --cleanup flag.',
+    is_flag=True,
+)
+@click.option(
+    '--cleanup',
+    help='For the given input filepaths and --output-dir arguments, delete any extra files that would not have been created by the batbot preprocess command. Skips hidden files starting with ".".
Acts as if --force-overwrite flag is given (does not delete existing, preprocessed outputs). WARNING: This will delete files, recommend running with the --dry-run flag first and carefully examining the output!', + is_flag=True, +) +@click.option( + '--no-file-structure', + help='(Not recommended) Turn off input file directory structure mirroring. All outputs will be written directly into the provided output dir. WARNING: If multiple input files have the same filename, outputs will overwrite!', + is_flag=True, +) +def preprocess( + filepaths, + output_dir, + process_metadata, + force_overwrite, + num_workers, + output_json, + dry_run, + cleanup, + no_file_structure, +): + """Generate compressed spectrogram images for wav files into the current working directory. + Takes one or more space separated arguments of filepaths to process. If given a directory name, + will recursively search through the directory and all subfolders to find all contained *.wav files. + Alternatively, the argument can be given as a string using wildcard ** for folders and/or * in filenames + (if ** wildcard is used, will recursively search through all subfolders). + + \b + Examples: + batbot preprocess ../data -o ./tmp + batbot preprocess "../data/**/*.wav" + batbot preprocess ../data -o ./tmp -n 32 + batbot preprocess ../data -o ./tmp -n 32 -fm + batbot preprocess ../data -o ./tmp -f --dry-run --output-json dry_run.json + batbot preprocess ../data -o ./tmp --cleanup + """ + in_filepaths = [] + for file in filepaths: + if isdir(file): + in_filepaths.extend(glob(join(file, '**/*.wav'), recursive=True)) + elif isfile(file): + in_filepaths.append(file) + else: + in_filepaths.extend(glob(file, recursive=True)) + # remove any repeats + in_filepaths = sorted(list(set(in_filepaths))) + + if len(in_filepaths) == 0: + print('Found no files given filepaths input {}'.format(filepaths)) + return + + # set up output paths for each input path + root_inpath = commonpath(in_filepaths) + root_outpath = '.' 
if output_dir is None else output_dir + makedirs(root_outpath, exist_ok=True) + if no_file_structure: + out_filepath_stems = [join(root_outpath, splitext(x)[0]) for x in in_filepaths] + else: + out_filepath_stems = [ + splitext(join(root_outpath, relpath(x, root_inpath)))[0] for x in in_filepaths + ] + new_dirs = [split(x)[0] for x in out_filepath_stems] + for new_dir in set(new_dirs): + makedirs(new_dir, exist_ok=True) + + # look for existing output files and remove from the set + in_filepaths = np.array(in_filepaths) + if dry_run or cleanup: + # save copy of all outputs before removing already processed data + out_filepath_stems_all = out_filepath_stems.copy() + out_filepath_stems = np.array(out_filepath_stems) + if not force_overwrite: + idx_remove = np.full((len(in_filepaths),), False) + for ii, out_file_stem in enumerate(out_filepath_stems): + test_file = '{}.*'.format(out_file_stem) + test_glob = glob(test_file) + if len(test_glob) > 0: + idx_remove[ii] = True + in_filepaths = in_filepaths[np.invert(idx_remove)] + out_filepath_stems = out_filepath_stems[np.invert(idx_remove)] + n_skipped = sum(idx_remove) + if len(in_filepaths) == 0: + print( + 'Found no unprocessed files given filepaths input {} and output directory "{}" after skipping {} files'.format( + filepaths, root_outpath, n_skipped + ) + ) + print('If desired, use --force-overwrite flag to overwrite existing processed data') + return + + if dry_run or cleanup: + # Find all "extra" files that would be deleted in cleanup mode + all_files = set(glob(join(root_outpath, '**/*'), recursive=True)) + for out_stem in out_filepath_stems_all: + out_files = glob('{}.*'.format(out_stem)) + all_files -= set(out_files) + dir_files = [] + # remove directories + for file in all_files: + if isdir(file): + dir_files.append(file) + all_files -= set(dir_files) + extra_files = all_files + + print('Located {} total unprocessed files'.format(len(in_filepaths))) + print('\tFast processing mode {}'.format('OFF' if process_metadata else 'ON')) + if process_metadata: + print('\t\tFull bat call metadata will be produced') + print('\tForce output overwrite {}'.format('ON' if force_overwrite else 'OFF')) + if not force_overwrite: + print('\t\tSkipped {} files with already preprocessed outputs'.format(n_skipped)) + print('\tNum parallel workers: {}'.format(num_workers)) + if no_file_structure: + print('\tFlattening output file structure') + print('\tCurrent working dir: {}'.format(getcwd())) + print('\tOutput root dir: {}'.format(output_dir)) + print( + '\tFirst input file -> output files: {} -> {}.*'.format( + in_filepaths[0], out_filepath_stems[0] + ) + ) + if len(in_filepaths) > 2: + print( + '\tLast input file -> output files: {} -> {}.*'.format( + in_filepaths[-1], out_filepath_stems[-1] + ) + ) + + if dry_run: + # Print out files to be processed, anticipated outputs, and files that would be deleted in cleanup mode. 
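+        # (Cleanup candidates are computed above as everything under the output
+        # root minus the anticipated outputs and directories; glob does not match
+        # dotfiles, so hidden files are never listed for deletion.)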
+ print('\nDry run mode active - skipping all processing') + data = {} + data['input file, output file stem'] = [ + (str(x), '{}.*'.format(y)) for x, y in zip(in_filepaths, out_filepath_stems) + ] + data['files to be deleted in cleanup'] = list(extra_files) + if output_json is None: + pprint.pp(data) + else: + with open(output_json, 'w') as outfile: + json.dump(data, outfile, indent=4) + print('Outputs written to {}'.format(output_json)) + print('Complete.') + return + + if cleanup: + print('\nCleanup mode active - skipping all processing') + if len(extra_files) == 0: + print('No files to delete') + else: + usr_in = input( + 'Found {} files to delete (recommend to see details by running with --dry-run flag). Continue (y/n)? '.format( + len(extra_files) + ) + ) + if usr_in.lower() not in ['y', 'yes']: + print('Aborting cleanup mode.') + return + for file in extra_files: + print('Deleting file: {}'.format(file)) + remove(file) + print('Complete.') + return + + # Begin execution loop. + data = {'output_path': [], 'compressed_path': [], 'metadata_path': [], 'failed_files': []} + if num_workers is None or num_workers == 0: + + # Serial execution. + for file, out_stem in tqdm( + zip(in_filepaths, out_filepath_stems), + desc='Preprocessing files', + total=len(in_filepaths), + ): + try: + output_paths, compressed_paths, metadata_path = batbot.pipeline( + file, + out_file_stem=out_stem, + fast_mode=(not process_metadata), + force_overwrite=force_overwrite, + quiet=True, + ) + data['output_path'].extend(output_paths) + data['compressed_path'].extend(compressed_paths) + if process_metadata: + data['metadata_path'].append(metadata_path) + except Exception as e: + warnings.warn('WARNING: Pipeline failed for file {}'.format(file)) + data['failed_files'].append((str(file), e)) + else: + # Parallel execution. 
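+        # Shuffling with a fixed seed (repeatable across runs) helps spread
+        # slow, large files evenly across the worker chunks built below: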
+ # shuffle input and output paths + zipped = np.stack((in_filepaths, out_filepath_stems), axis=-1) + np.random.seed(0) + np.random.shuffle(zipped) + assert all( + [ + x in zipped[:, 0] and y in zipped[:, 1] + for x, y in zip(in_filepaths, out_filepath_stems) + ] + ) + in_filepaths, out_filepath_stems = zipped.T + assert all([basename(y) in basename(x) for x, y in zip(in_filepaths, out_filepath_stems)]) + + # make num_workers chunks + in_file_chunks = np.array_split(in_filepaths, num_workers) + out_stem_chunks = np.array_split(out_filepath_stems, num_workers) + + # send to parallel function + output_paths, compressed_paths, metadata_paths, failed_files = batbot.parallel_pipeline( + in_file_chunks=in_file_chunks, + out_stem_chunks=out_stem_chunks, + fast_mode=(not process_metadata), + force_overwrite=force_overwrite, + num_workers=num_workers, + threaded=False, + quiet=True, + desc='Preprocessing chunks of files with {} workers'.format(num_workers), + ) + data['output_path'].extend(output_paths) + data['compressed_path'].extend(compressed_paths) + if process_metadata: + data['metadata_path'].extend(metadata_paths) + data['failed_files'].extend(failed_files) + + if output_json is None: + print('\nFull spectrogram output paths:') + pprint.pp(sorted(data['output_path'])) + print('\nCompressed spectrogram output paths:') + pprint.pp(sorted(data['compressed_path'])) + if process_metadata: + print('\nProcessed metadata paths:') + pprint.pp(sorted(data['metadata_path'])) + print('\nFiles skipped due to failure, and corresponding exceptions:') + pprint.pp(sorted(data['failed_files'])) + else: + with open(output_json, 'w') as outfile: + json.dump(data, outfile, indent=4) + print('Outputs written to {}'.format(output_json)) + print('\nComplete.') + + return data + + +@click.command('batch') +@click.argument( + 'filepaths', + nargs=-1, + type=str, +) +@click.option( + '--config', + help='Which ML model to use for inference', + default=None, + type=click.Choice(['usgs']), +) +@click.option( + '--output', + help='Path to output JSON (if unspecified, results are printed to screen)', + default=None, + type=str, +) +# @click.option( +# '--classifier_thresh', +# help='Classifier confidence threshold', +# default=int(classifier.CONFIGS[None]['thresh'] * 100), +# type=click.IntRange(0, 100, clamp=True), +# ) +def batch( + filepaths, + config, + output, + # classifier_thresh, +): + """ + Run the BatBot pipeline in batch on a list of input WAV filepaths. + An example output of the JSON can be seen below. + + .. code-block:: javascript + + { + '/path/to/file1.wav': { + 'classifier': 0.5, + }, + '/path/to/file2.wav': { + 'classifier': 0.8, + }, + ... + } + """ + if config is not None: + config = config.strip().lower() + # classifier_thresh /= 100.0 + + log.debug(f'Running batch on {len(filepaths)} files...') + + score_list = batbot.batch( + filepaths, + config=config, + # classifier_thresh=classifier_thresh, + ) + + data = {} + for filepath, score in zip(filepaths, score_list): + data[filepath] = { + 'classifier': score, + } + + log.debug('Outputting results...') + if output: + with open(output, 'w') as outfile: + json.dump(data, outfile, indent=4) + else: + print(data) + + +@click.command('example') +def example(): + """ + Run a test of the pipeline on an example WAV with the default configuration. 
+ """ + batbot.example() + + +@click.group() +def cli(): + """ + BatBot CLI + """ + pass + + +cli.add_command(fetch) +cli.add_command(pipeline) +cli.add_command(preprocess) +cli.add_command(batch) +cli.add_command(example) + + +if __name__ == '__main__': + cli() diff --git a/batbot/spectrogram/__init__.py b/batbot/spectrogram/__init__.py index f1fa0f4..99fb956 100644 --- a/batbot/spectrogram/__init__.py +++ b/batbot/spectrogram/__init__.py @@ -4,32 +4,31 @@ import os import shutil import warnings -from os.path import basename, exists, join, splitext +from glob import glob +from os.path import basename, exists, join, split, splitext import cv2 import librosa -import librosa.display import matplotlib.pyplot as plt # import networkx as nx import numpy as np import pyastar2d -import scipy.signal # Ensure this is at the top with other imports +import scipy.stats import tqdm -from line_profiler import LineProfiler + +# from line_profiler import LineProfiler from scipy import ndimage # from PIL import Image -from scipy.ndimage import gaussian_filter1d - -# from scipy.ndimage.filters import maximum_filter1d +from scipy.ndimage import gaussian_filter1d, median_filter from shapely.geometry import Point from shapely.geometry.polygon import Polygon from skimage import draw, measure from batbot import log -lp = LineProfiler() +# lp = LineProfiler() FREQ_MIN = 5e3 @@ -38,11 +37,13 @@ def compute(*args, **kwargs): retval = compute_wrapper(*args, **kwargs) - lp.print_stats() + # if not kwargs.get('fast_mode', True) and not kwargs.get('quiet', True): + # lp.print_stats() return retval def get_islands(data): + # Find all islands of contiguous same elements with island length 2 or more mask = np.r_[np.diff(data) == 0, False] mask_ = np.concatenate(([False], mask, [False])) idx_ = np.flatnonzero(mask_[1:] != mask_[:-1]) @@ -50,7 +51,7 @@ def get_islands(data): def get_slope_islands(slope_flags): - flags = slope_flags.astype(np.uint8) + flags = slope_flags.astype(np.uint16) islands = get_islands(flags) idx = int(np.argmax([val.sum() for val in islands])) islands = [val * (1 if i == idx else 0) for i, val in enumerate(islands)] @@ -112,6 +113,7 @@ def plot_histogram( ignore_zeros=False, max_val=None, smoothing=128, + min_band_idx=None, csum_threshold=0.95, output_path='.', output_filename='histogram.png', @@ -119,6 +121,9 @@ def plot_histogram( if max_val is None: max_val = int(image.max()) + if min_band_idx is not None: + image = image[min_band_idx:, :] + if ignore_zeros: image = image[image > 0] @@ -133,10 +138,12 @@ def plot_histogram( # if ignore_zeros: # assert hist[0] == 0 - hist_original = hist.copy() + if output_path: + hist_original = hist.copy() if smoothing: hist = gaussian_filter1d(hist, smoothing, mode='nearest') - hist_original = (hist_original / hist_original.max()) * hist.max() + if output_path: + hist_original = (hist_original / hist_original.max()) * hist.max() mode_ = np.argmax(hist) # histogram mode @@ -237,28 +244,40 @@ def generate_waveplot( return waveplot +# @lp def load_stft( - wav_filepath, sr=250e3, n_fft=512, window='blackmanharris', win_length=256, hop_length=16 + wav_filepath, + sr=250e3, + n_fft=512, + window='blackmanharris', + win_length=256, + hop_length=16, + fast_mode=False, ): assert exists(wav_filepath) log.debug(f'Computing spectrogram on {wav_filepath}') # Load WAV file try: - waveform_, sr_ = librosa.load(wav_filepath, sr=None) - duration = len(waveform_) / sr_ + waveform_, orig_sr = librosa.load(wav_filepath, sr=None) + duration = len(waveform_) / orig_sr except 
Exception as e: raise OSError(f'Error loading file: {e}') # Resample the waveform - waveform = librosa.resample(waveform_, orig_sr=sr_, target_sr=sr) + waveform = librosa.resample(waveform_, orig_sr=orig_sr, target_sr=sr) + + # TODO: signal processing: remove DC offset, time window edges of waveform # Convert the waveform to a (complex) spectrogram stft = librosa.stft( waveform, n_fft=n_fft, window=window, win_length=win_length, hop_length=hop_length ) # Convert the complex power (amplitude + phase) into amplitude (decibels) - stft_db = librosa.power_to_db(np.abs(stft) ** 2, ref=np.max) + # Do not threshold the data - threshold will be applied later + # stft_db = librosa.power_to_db(np.abs(stft) ** 2, ref=np.max, top_db=np.inf) # OLD method, cuts off lower values + abs_sq_stft = np.square(np.abs(stft)) + stft_db = 10 * np.log10(abs_sq_stft / abs_sq_stft.max() + 1e-20) # Retrieve time vector in seconds corresponding to STFT time_vec = librosa.frames_to_time( range(stft_db.shape[1]), sr=sr, hop_length=hop_length, n_fft=n_fft @@ -282,12 +301,23 @@ def load_stft( stft_db = stft_db[min_index : max_index + 1, :] bands = bands[min_index : max_index + 1] - waveplot = generate_waveplot(waveform, stft_db, hop_length=hop_length) + if fast_mode: + waveplot = [] + else: + waveplot = generate_waveplot(waveform, stft_db, hop_length=hop_length) - return stft_db, waveplot, sr, bands, duration, min_index, time_vec + # Estimate maximum frequency band containing data based on original sample rate + # Only data up to this maximum band should be used when computing statistics + max_band_idx = min((int(np.where(bands < orig_sr / 2.02)[0][-1]), len(bands) - 1)) + # set non-physical noise above the max band to a minimum value + if max_band_idx < len(bands) - 1: + stft_db[max_band_idx + 1 :, :] = np.min(stft_db[: max_band_idx + 1, :]) + return stft_db, waveplot, sr, bands, duration, min_index, time_vec, orig_sr, max_band_idx -def gain_stft(stft_db, gain_db=80.0, autogain_stddev=5.0): + +# @lp +def gain_stft(stft_db, gain_db=120.0, autogain_stddev=5.0, max_band_idx=None): # Subtract per-frequency median DB med = np.median(stft_db, axis=1).reshape(-1, 1) stft_db -= med @@ -300,7 +330,10 @@ def gain_stft(stft_db, gain_db=80.0, autogain_stddev=5.0): # Calculate the non-zero median DB and MAD # autogain signal if (median - alpha * deviation) is higher than provided gain - temp = stft_db[stft_db > 0] + if max_band_idx is not None: + temp = stft_db[: max_band_idx + 1, :][stft_db[: max_band_idx + 1, :] > 0] + else: + temp = stft_db[stft_db > 0] med_db = np.median(temp) std_db = scipy.stats.median_abs_deviation(temp, axis=None, scale='normal') autogain_value = med_db - (autogain_stddev * std_db) @@ -361,9 +394,12 @@ def calculate_window_and_stride( return window, stride -def create_coarse_candidates(stft_db, window, stride, threshold_stddev=3.0): +def create_coarse_candidates(stft_db, window, stride, threshold_stddev=3.0, min_band_idx=None): # Re-calculate the non-zero median DB and MAD (scaled like std) - temp = stft_db[stft_db > 0] + if min_band_idx is not None: + temp = stft_db[min_band_idx:, :][stft_db[min_band_idx:, :] > 0] + else: + temp = stft_db[stft_db > 0] med_db = np.median(temp) std_db = scipy.stats.median_abs_deviation(temp, axis=None, scale='normal') threshold = med_db + (threshold_stddev * std_db) @@ -389,24 +425,39 @@ def create_coarse_candidates(stft_db, window, stride, threshold_stddev=3.0): return candidates, candidate_dbs +# @lp def filter_candidates_to_ranges( - stft_db, candidates, window=16, 
skew_stddev=2.0, area_percent=0.10, output_path=None + stft_db, + candidates, + window=16, + skew_stddev=2.0, + area_percent=0.10, + min_band_idx=None, + output_path=None, + fast_mode=False, + quiet=False, ): # Filter the candidates based on their distribution skewness - stride_ = 2 + stride_ = 2 if not fast_mode else 4 buffer = int(round(window / stride_ / 2)) reject_idxs = [] ranges = [] - for index, (idx, start, stop) in tqdm.tqdm(list(enumerate(candidates))): + for index, (idx, start, stop) in tqdm.tqdm(list(enumerate(candidates)), disable=quiet): # Extract the candidate window of the STFT - candidate = stft_db[:, start:stop] + if min_band_idx is not None: + candidate = stft_db[min_band_idx:, start:stop] + else: + candidate = stft_db[:, start:stop] # Create a vertical (frequency) sliding window of Numpy views views = np.lib.stride_tricks.sliding_window_view(candidate, (window, candidate.shape[1]))[ ::stride_, 0 ] - skews = scipy.stats.skew(views, axis=(1, 2)) + with warnings.catch_warnings(): + # handle cases with mono-valued data + warnings.simplefilter('ignore', category=RuntimeWarning) + skews = scipy.stats.skew(views, axis=(1, 2)) # Center and clip the skew values skew_thresh = calculate_mean_within_stddev_window(skews, skew_stddev) @@ -415,10 +466,13 @@ def filter_candidates_to_ranges( skews = normalize_skew(skews, skew_thresh) # Calculate the largest contiguous island of non-zeros - skews = (skews > 0).astype(np.uint8) + skews = (skews > 0).astype(np.uint16) islands = get_islands(skews) area = float(max([val.sum() for val in islands])) area /= len(skews) + if area == 0.0 and sum(skews) >= 1: + # handle edge case with single-element islands + area = 1.0 / len(skews) if area >= area_percent: ranges.append((start, stop)) @@ -515,7 +569,7 @@ def normalize_skew(skews, skew_thresh): with warnings.catch_warnings(): warnings.simplefilter('ignore', category=RuntimeWarning) - skews /= skews.max() + skews /= np.nanmax(skews) skews = np.nan_to_num(skews, nan=0.0, posinf=0.0, neginf=-0.0) @@ -524,15 +578,22 @@ def normalize_skew(skews, skew_thresh): def calculate_mean_within_stddev_window(values, window): # Calculate the average skew within X standard deviations (temperature scaling) - values_mean = np.mean(values) - values_std = np.std(values) + values_mean = np.nanmean(values) + values_std = np.nanstd(values) values_flags = np.abs(values - values_mean) <= (values_std * window) values_mean_windowed = values[values_flags].mean() return values_mean_windowed def tighten_ranges( - stft_db, ranges, window, duration, skew_stddev=2.0, min_duration_ms=2.0, output_path='.' 
+ stft_db, + ranges, + window, + duration, + skew_stddev=2.0, + min_duration_ms=2.0, + output_path='.', + quiet=False, ): minimum_duration = int(np.around(stft_db.shape[1] / (duration * 1e3) * min_duration_ms)) @@ -541,15 +602,18 @@ def tighten_ranges( buffer = int(round(window / stride_ / 2)) ranges_ = [] - for index, (start, stop) in tqdm.tqdm(list(enumerate(ranges))): + for index, (start, stop) in tqdm.tqdm(list(enumerate(ranges)), disable=quiet): # Extract the candidate window of the STFT candidate = stft_db[:, start:stop] - # Create a vertical (frequency) sliding window of Numpy views + # Create a horizontal (time) sliding window of Numpy views views = np.lib.stride_tricks.sliding_window_view(candidate, (candidate.shape[0], window))[ 0, ::stride_ ] - skews = scipy.stats.skew(views, axis=(1, 2)) + with warnings.catch_warnings(): + # handle cases with mono-valued data + warnings.simplefilter('ignore', category=RuntimeWarning) + skews = scipy.stats.skew(views, axis=(1, 2)) # Center and clip the skew values skew_thresh = calculate_mean_within_stddev_window(skews, skew_stddev) @@ -557,7 +621,7 @@ def tighten_ranges( # Calculate the largest contiguous island of non-zeros skew_flags = skews > 0 - skews = skew_flags.astype(np.uint8) + skews = skew_flags.astype(np.uint16) islands = get_islands(skews) islands = [(index + 1) * val for index, val in enumerate(islands)] island = np.hstack(islands) @@ -698,7 +762,7 @@ def threshold_contour(segment, index, output_path='.'): def filter_contour(segment, index, med_db=None, std_db=None, kernel=5, output_path='.'): # segment = cv2.erode(segment, np.ones((3, 3), np.uint8), iterations=1) - segment = scipy.signal.medfilt(segment, kernel_size=kernel) + segment = median_filter(segment, size=(kernel, kernel), mode='reflect') if None not in {med_db, std_db}: segment_threshold = med_db - std_db @@ -1067,13 +1131,14 @@ def significant_contour_path( return bandwidth, duration, significant -def scale_pdf_contour(segment, index, output_path='.'): +def scale_pdf_contour(segment, index, min_band_idx=None, output_path='.'): segment = normalize_stft(segment, None, segment.dtype) med_db, std_db, peak_db = plot_histogram( segment, smoothing=512, ignore_zeros=True, csum_threshold=0.9, + min_band_idx=min_band_idx, output_path=output_path, output_filename=f'contour.{index}.00.histogram.png', ) @@ -1171,9 +1236,15 @@ def find_contour_and_peak( # (note that these were computed prior to CDF weighting) threshold = peak_db - threshold_std * peak_db_std + # pad all edges to handle signal that butts up against segment edges + segment_pad = np.pad(segment, ((2, 2), (2, 2))) contours = measure.find_contours( - segment, level=threshold, fully_connected='high', positive_orientation='high' + segment_pad, level=threshold, fully_connected='high', positive_orientation='high' ) + # remove padding in output contour + for contour in contours: + contour[:, 0] -= 2 + contour[:, 1] -= 2 # Display the image and plot all contours found if output_path: @@ -1203,6 +1274,11 @@ def find_contour_and_peak( contour_ = np.vstack((y, x), dtype=contour.dtype).T polygon_ = Polygon(contour).convex_hull + + # Add small buffer to smoothed contour be sure to include maximum value location. 
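+        # (Shapely's buffer(1.0) dilates the contour polygon outward by one
+        # pixel in index space, so a peak lying exactly on the contour line
+        # still tests as inside the polygon.)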
+ polygon = Polygon(contour).buffer(1.0) + xx, yy = polygon.exterior.coords.xy + contour_ = np.vstack((xx, yy)).T assert idx not in counter counter[idx] = (found, polygon_) @@ -1253,11 +1329,17 @@ def calculate_harmonic_and_echo_flags( write_contour_debug_image(negative_, index, 7, 'negative', output_path=output_path) negative_skew = scipy.stats.skew(original[np.logical_and(nonzeros, negative)]) - harmonic_skew = scipy.stats.skew(original[np.logical_and(nonzeros, harmonic)]) - negative_skew - echo_skew = ( - scipy.stats.skew(original[np.logical_and(np.logical_and(nonzeros, echo), ~harmonic)]) - - negative_skew - ) + with warnings.catch_warnings(): + # allow for nan outputs in cases of empty or mono-valued selections + warnings.simplefilter('ignore', category=RuntimeWarning) + selection = np.logical_and(nonzeros, harmonic) + harmonic_skew = scipy.stats.skew(original[selection]) - negative_skew + selection = np.logical_and(np.logical_and(nonzeros, echo), ~harmonic) + echo_skew = scipy.stats.skew(original[selection]) - negative_skew + if np.isnan(harmonic_skew): + harmonic_skew = -np.inf + if np.isnan(echo_skew): + echo_skew = -np.inf skew_thresh = np.abs(negative_skew * 0.1) harmonic_flag = harmonic_skew >= skew_thresh @@ -1298,9 +1380,19 @@ def calculate_harmonic_and_echo_flags( return harmonic_flag, harmonic_peak, echo_flag, echo_peak -@lp +# @lp def compute_wrapper( - wav_filepath, annotations=None, output_folder='.', bitdepth=16, debug=False, **kwargs + wav_filepath, + out_file_stem=None, + output_folder=None, + fast_mode=False, + force_overwrite=False, + quiet=False, + annotations=None, + bitdepth=16, + mask_secondary_effects=False, + debug=False, + **kwargs, ): """ Compute the spectrograms for a given input WAV and saves them to disk. @@ -1322,25 +1414,51 @@ def compute_wrapper( - tuple of spectrogram's (min, max) frequency - list of spectrogram filepaths, split by 50k horizontal pixels """ - base = splitext(basename(wav_filepath))[0] + if not force_overwrite: + test_file = '{}.*'.format(out_file_stem) + test_glob = glob(test_file) + if len(test_glob) > 0: + if not quiet: + print( + 'NOTE: Found existing file(s) at {} with force_overwrite=False. 
Skipping file.'.format( + test_file + ) + ) + return [], [], [], {} + if fast_mode: + bitdepth = 8 + quiet = True assert bitdepth in [8, 16] dtype = np.uint8 if bitdepth == 8 else np.uint16 chunksize = int(50e3) - # create output folder if it doesn't exist + # Default to retrieving the output_folder from out_file_stem + if out_file_stem is not None: + output_folder = split(out_file_stem)[0] + if output_folder is None: + output_folder = './output' + # If no out_file_stem is given, default to the wav file basename joined with output_folder + if out_file_stem is None: + out_file_stem = join(output_folder, splitext(basename(wav_filepath))[0]) + + debug_path = get_debug_path(output_folder, wav_filepath, enabled=debug) + # Create output folder if it doesn't exist if not os.path.exists(output_folder): os.makedirs(output_folder) assert exists(output_folder) - debug_path = get_debug_path(output_folder, wav_filepath, enabled=debug) - # Load the spectrogram from a WAV file on disk - stft_db, waveplot, sr, bands, duration, freq_offset, time_vec = load_stft(wav_filepath) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=DeprecationWarning) + # ignore warning due to aifc deprecation + stft_db, waveplot, sr, bands, duration, freq_offset, time_vec, orig_sr, max_band_idx = ( + load_stft(wav_filepath, fast_mode=fast_mode) + ) # Apply a dynamic range to a fixed dB range - stft_db = gain_stft(stft_db) + stft_db = gain_stft(stft_db, max_band_idx=max_band_idx) # Bin the floating point data to X-bit integers (X=8 or X=16) stft_db = normalize_stft(stft_db, None, dtype) @@ -1352,24 +1470,66 @@ def compute_wrapper( y_step_freq = float(bands[0] - bands[1]) x_step_ms = float(1e3 * (time_vec[1] - time_vec[0])) bands = np.around(bands).astype(np.int32).tolist() + min_band_idx = len(bands) - max_band_idx - 1 + + if not fast_mode: + # Plot the histogram, ignoring any non-zero values (will no-op if output_path is None) + global_med_db, global_std_db, global_peak_db = plot_histogram( + stft_db, + ignore_zeros=True, + smoothing=512, + min_band_idx=min_band_idx, + output_path=debug_path, + ) + # Estimate a global threshold for finding the edges of bat call contours + global_threshold_std = 2.0 + global_threshold = global_peak_db - global_threshold_std * global_std_db + else: + # Fast mode skips bat call segmentation + global_threshold = 0.0 - # # Save the spectrogram image to disk - # cv2.imwrite('debug.tif', stft_db, [cv2.IMWRITE_TIFF_COMPRESSION, 1]) - - # Plot the histogram, ignoring any non-zero values (will no-op if output_path is None) - global_med_db, global_std_db, global_peak_db = plot_histogram( - stft_db, ignore_zeros=True, smoothing=512, output_path=debug_path + # Get a distribution of the max candidate locations + # Normal mode uses a relatively large window and little overlap + # Fast mode uses a relatively small window and lots of overlap, since it skips range tightening step + strides_per_window = 3 if not fast_mode else 16 + window_size_ms = 12 if not fast_mode else 3 + threshold_stddev = 3.0 if not fast_mode else 4.0 + window, stride = calculate_window_and_stride( + stft_db, + duration, + window_size_ms=window_size_ms, + strides_per_window=strides_per_window, + time_vec=time_vec, + ) + candidates, candidate_max_dbs = create_coarse_candidates( + stft_db, + window, + stride, + threshold_stddev=threshold_stddev, + min_band_idx=min_band_idx, ) - # Estimate a global threshold for finding the edges of bat call contours - global_threshold_std = 2.0 - global_threshold = global_peak_db - 
global_threshold_std * global_std_db - # Get a distribution of the max candidate locations - window, stride = calculate_window_and_stride(stft_db, duration, time_vec=time_vec) - candidates, candidate_max_dbs = create_coarse_candidates(stft_db, window, stride) + if fast_mode: + # combine candidates for efficiency and remove very short candidates (likely noise) + tmp_ranges = [(x, y) for _, x, y in candidates] + tmp_ranges = merge_ranges(tmp_ranges, stft_db.shape[1]) + candidate_lengths = np.array([y - x for x, y in tmp_ranges]) + length_thresh = window * 1.5 + idx_remove = candidate_lengths < length_thresh + candidates = [(ii, x, y) for ii, (x, y) in enumerate(tmp_ranges) if not idx_remove[ii]] + candidate_max_dbs = [] # Filter all candidates to the ranges that have a substantial right-side skew - ranges, reject_idxs = filter_candidates_to_ranges(stft_db, candidates, output_path=debug_path) + + ranges, reject_idxs = filter_candidates_to_ranges( + stft_db, + candidates, + area_percent=0.01, + min_band_idx=min_band_idx, + output_path=debug_path, + fast_mode=fast_mode, + quiet=quiet, + ) # Add in user-specified annotations to ranges if annotations: @@ -1384,172 +1544,218 @@ def compute_wrapper( # Plot the chirp candidates (will no-op if output_path is None) plot_chirp_candidates(stft_db, candidate_max_dbs, ranges, reject_idxs, output_path=debug_path) - # Tighten the ranges by looking for substantial right-side skew (use stride for a smaller sampling window) - ranges = tighten_ranges(stft_db, ranges, stride, duration, output_path=debug_path) - - # Extract chirp metrics and metadata - segments = { - 'stft_db': [], - 'waveplot': [], - 'costs': [], - 'canvas': [], - } - metas = [] - for index, (start, stop) in tqdm.tqdm(list(enumerate(ranges))): - segment = stft_db[:, start:stop] - - # Step 0.1 - Debugging setup and find peak amplitude (will return None if disabled) - canvas = create_contour_debug_canvas(segment, index, output_path=debug_path) - - # Step 0.2 - Find the location(s) of peak amplitude - max_locations = find_max_locations(segment) + if fast_mode: + # Apply reduced processing without segment refinement or per-call metadata calculation + + segments = {'stft_db': []} + # Remove a fraction of the window length when not doing call segmentation + crop_length_l = max(0, int(round(0.15 * window - 1))) + crop_length_r = max(0, int(round(0.45 * window - 1))) + metas = [] + for start, stop in ranges: + if start > 0 and stop < stft_db.shape[1]: + segments['stft_db'].append(stft_db[:, start + crop_length_l : stop - crop_length_r]) + elif start > 0: + # handle cases where candidate butts up against data edges + segments['stft_db'].append(stft_db[:, start + crop_length_l : stop]) + else: + segments['stft_db'].append(stft_db[:, start : stop - crop_length_r]) + # Add basic metadata + metadata = { + 'segment start.ms': (start + crop_length_l) * x_step_ms, + 'segment end.ms': (stop - crop_length_r) * x_step_ms, + 'segment duration.ms': (stop - crop_length_r - start - crop_length_l) * x_step_ms, + } + # Normalize values + for key, value in list(metadata.items()): + if key.endswith('.ms'): + metadata[key] = round(float(value), 3) + metas.append(metadata) - # Step 1 - Scale with PDF - segment, peak_db, peak_db_std = scale_pdf_contour(segment, index, output_path=debug_path) - if None in {peak_db, peak_db_std}: - continue - - # Step 2 - Apply median filtering to contour - segment = filter_contour(segment, index, output_path=debug_path) - - # Step 3 - Apply Morphology Open to contour - segment = 
morph_open_contour(segment, index, output_path=debug_path) + else: - # Step 4 - Normalize contour - segment = normalize_contour(segment, index, output_path=debug_path) + # Tighten the ranges by looking for substantial right-side skew (use stride for a smaller sampling window) + ranges = tighten_ranges( + stft_db, ranges, stride, duration, output_path=debug_path, quiet=quiet + ) - # # Step 5 (OLD) - Threshold contour - # segment, med_db, std_db, peak_db = threshold_contour(segment, index, output_path=debug_path) + # Extract chirp metrics and metadata + segments = { + 'stft_db': [], + 'waveplot': [], + 'costs': [], + 'canvas': [], + } + metas = [] + for index, (start, stop) in tqdm.tqdm(list(enumerate(ranges)), disable=quiet): + segment = stft_db[:, start:stop] - # Step 5 - Find primary contour that contains max amplitude - # (To use a local instead of global threshold, remove the threshold argument here) - segmentmask, peak, segment_threshold = find_contour_and_peak( - segment, - index, - max_locations, - peak_db, - peak_db_std, - output_path=debug_path, - threshold=global_threshold, - ) + # Step 0.1 - Debugging setup and find peak amplitude (will return None if disabled) + canvas = create_contour_debug_canvas(segment, index, output_path=debug_path) - if peak is None: - continue + # Step 1 - Scale with PDF + segment, peak_db, peak_db_std = scale_pdf_contour( + segment, index, min_band_idx=min_band_idx, output_path=debug_path + ) + if None in {peak_db, peak_db_std}: + continue - # Step 6 - Create final segmentmask - segmentmask = refine_segmentmask(segmentmask, index, output_path=debug_path) + # Step 2 - Apply median filtering to contour + segment = filter_contour(segment, index, output_path=debug_path) - # # Step 6 (OLD) - Find the contour with the (most) max amplitude location(s) - # valid, segmentmask, peak = find_contour_connected_components(segment, index, max_locations, output_path=debug_path) - # # Step 6 (OLD) - Refine contour by removing any harmonic or echo - # segmentmask, peak = refine_contour(segment_, index, max_locations, segmentmask, peak, output_path=debug_path) + # Step 3 - Apply Morphology Open to contour + segment = morph_open_contour(segment, index, output_path=debug_path) - # Step 7 - Calculate the first order harmonic and echo region - harmonic = find_harmonic(segmentmask, index, freq_offset, output_path=debug_path) - echo = find_echo(segmentmask, index, output_path=debug_path) + # Step 4 - Normalize contour + segment = normalize_contour(segment, index, output_path=debug_path) - original = stft_db[:, start:stop] - harmonic_flag, hamonic_peak, echo_flag, echo_peak = calculate_harmonic_and_echo_flags( - original, index, segmentmask, harmonic, echo, canvas, output_path=debug_path - ) + # Step 4.1 - Find the location(s) of peak amplitude + max_locations = find_max_locations(segment) - # Remove harmonic and echo from segmentation - segment = remove_harmonic_and_echo( - segment, index, harmonic, echo, global_threshold, output_path=debug_path - ) + # Step 5 - Find primary contour that contains max amplitude + # (To use a local instead of global threshold, remove the threshold argument here) + segmentmask, peak, segment_threshold = find_contour_and_peak( + segment, + index, + max_locations, + peak_db, + peak_db_std, + output_path=debug_path, + threshold=global_threshold, + ) - # Step 8 - Calculate the A* cost grid and bat call start/end points - costs, grid, call_begin, call_end, boundary = calculate_astar_grid_and_endpoints( - segment, index, segmentmask, peak, canvas, 
output_path=debug_path - ) - top, bottom, left, right = boundary + if peak is None: + continue - # Skip chirp if the extracted path covers a small duration or bandwidth - bandwidth, duration_, significant = significant_contour_path( - call_begin, call_end, y_step_freq, x_step_ms - ) - if not significant: - continue + # Step 6 - Create final segmentmask + segmentmask = refine_segmentmask(segmentmask, index, output_path=debug_path) - # Step 9 - Extract optimal path from start to end using the cost grid - path = extract_contour_path( - grid, call_begin, call_end, canvas, index, output_path=debug_path - ) + # Step 7 - Calculate the first order harmonic and echo region + harmonic = find_harmonic(segmentmask, index, freq_offset, output_path=debug_path) + echo = find_echo(segmentmask, index, output_path=debug_path) - # Step 10 - Extract contour keypoints - path_smoothed, (knee, fc, heel), slopes = extract_contour_keypoints( - path, canvas, index, peak, output_path=debug_path - ) + original = stft_db[:, start:stop] + harmonic_flag, hamonic_peak, echo_flag, echo_peak = calculate_harmonic_and_echo_flags( + original, index, segmentmask, harmonic, echo, canvas, output_path=debug_path + ) - # Step 11 - Collect chirp metadata - metadata = { - 'curve.(hz,ms)': [ - ( - bands[y], - (start + x) * x_step_ms, + # Remove harmonic and echo from segmentation + if mask_secondary_effects: + segment = remove_harmonic_and_echo( + segment, index, harmonic, echo, global_threshold, output_path=debug_path ) - for y, x in path_smoothed - ], - 'start.ms': (start + left) * x_step_ms, - 'end.ms': (start + right) * x_step_ms, - 'duration.ms': (right - left) * x_step_ms, - 'threshold.amp': int(round(255.0 * (segment_threshold / np.iinfo(stft_db.dtype).max))), - 'peak f.ms': (start + peak[1]) * x_step_ms, - 'fc.ms': (start + bands[fc[1]]) * x_step_ms, - 'hi fc:knee.ms': (start + bands[knee[1]]) * x_step_ms, - 'lo fc:heel.ms': (start + bands[heel[1]]) * x_step_ms, - 'bandwidth.hz': bandwidth, - 'hi f.hz': bands[top], - 'lo f.hz': bands[bottom], - 'peak f.hz': bands[peak[0]], - 'fc.hz': bands[fc[0]], - 'hi fc:knee.hz': bands[knee[0]], - 'lo fc:heel.hz': bands[heel[0]], - 'harmonic.flag': harmonic_flag, - 'harmonic peak f.ms': (start + hamonic_peak[1]) * x_step_ms if harmonic_flag else None, - 'harmonic peak f.hz': bands[hamonic_peak[0]] if harmonic_flag else None, - 'echo.flag': echo_flag, - 'echo peak f.ms': (start + echo_peak[1]) * x_step_ms if echo_flag else None, - 'echo peak f.hz': bands[echo_peak[0]] if echo_flag else None, - } - metadata.update(slopes) - # Normalize values - for key, value in list(metadata.items()): - if value is None: + # Step 8 - Calculate the A* cost grid and bat call start/end points + costs, grid, call_begin, call_end, boundary = calculate_astar_grid_and_endpoints( + segment, index, segmentmask, peak, canvas, output_path=debug_path + ) + top, bottom, left, right = boundary + + # Skip chirp if the extracted path covers a small duration or bandwidth + min_bandwidth_khz = 1e3 + min_duration_ms = 2.0 + bandwidth, duration_, significant = significant_contour_path( + call_begin, + call_end, + y_step_freq, + x_step_ms, + min_bandwidth_khz=min_bandwidth_khz, + min_duration_ms=min_duration_ms, + ) + if not significant: continue - if key.endswith('.ms'): - metadata[key] = round(float(value), 3) - if key.endswith('.hz'): - metadata[key] = int(round(value)) - if key.endswith('.flag'): - metadata[key] = bool(value) - if key.endswith('.y_px/x_px'): - key_ = key.replace('.y_px/x_px', '.khz/ms') - metadata[key_] = 
round(float(value * ((y_step_freq / 1000.0) / x_step_ms)), 3) - metadata.pop(key) - if key.endswith('.(hz,ms)'): - metadata[key] = [ - ( - int(round(val1)), - round(float(val2), 3), - ) - for val1, val2 in value - ] - metas.append(metadata) + # Step 9 - Extract optimal path from start to end using the cost grid + path = extract_contour_path( + grid, call_begin, call_end, canvas, index, output_path=debug_path + ) - # Trim segment around the bat call with a small buffer - buffer_ms = 1.0 - buffer_pix = int(round(buffer_ms / x_step_ms)) - trim_begin = max(0, min(segment.shape[1], call_begin[1] - buffer_pix)) - trim_end = max(0, min(segment.shape[1], call_end[1] + buffer_pix)) + # Step 10 - Extract contour keypoints + path_smoothed, (knee, fc, heel), slopes = extract_contour_keypoints( + path, canvas, index, peak, output_path=debug_path + ) - segments['stft_db'].append(stft_db[:, start + trim_begin : start + trim_end]) - segments['waveplot'].append(waveplot[:, start + trim_begin : start + trim_end]) - segments['costs'].append(costs[:, trim_begin:trim_end]) - if debug_path: - segments['canvas'].append(canvas[:, trim_begin:trim_end]) + # Step 11 - Collect chirp metadata + metadata = { + 'curve.(hz,ms)': [ + ( + bands[y], + (start + x) * x_step_ms, + ) + for y, x in path_smoothed + ], + 'contour start.ms': (start + left) * x_step_ms, + 'contour end.ms': (start + right) * x_step_ms, + 'contour duration.ms': (right - left) * x_step_ms, + 'threshold.amp': int( + round(255.0 * (segment_threshold / np.iinfo(stft_db.dtype).max)) + ), + 'peak f.ms': (start + peak[1]) * x_step_ms, + 'fc.ms': (start + fc[1]) * x_step_ms, + 'hi fc:knee.ms': (start + knee[1]) * x_step_ms, + 'lo fc:heel.ms': (start + heel[1]) * x_step_ms, + 'bandwidth.hz': bandwidth, + 'hi f.hz': bands[top], + 'lo f.hz': bands[bottom], + 'peak f.hz': bands[peak[0]], + 'fc.hz': bands[fc[0]], + 'hi fc:knee.hz': bands[knee[0]], + 'lo fc:heel.hz': bands[heel[0]], + 'harmonic.flag': harmonic_flag, + 'harmonic peak f.ms': ( + (start + hamonic_peak[1]) * x_step_ms if harmonic_flag else None + ), + 'harmonic peak f.hz': bands[hamonic_peak[0]] if harmonic_flag else None, + 'echo.flag': echo_flag, + 'echo peak f.ms': (start + echo_peak[1]) * x_step_ms if echo_flag else None, + 'echo peak f.hz': bands[echo_peak[0]] if echo_flag else None, + } + metadata.update(slopes) + + # Trim segment around the bat call with a small buffer + buffer_ms = 1.0 + buffer_pix = int(round(buffer_ms / x_step_ms)) + trim_begin = max(0, min(segment.shape[1], call_begin[1] - buffer_pix)) + trim_end = max(0, min(segment.shape[1], call_end[1] + buffer_pix)) + + segments['stft_db'].append(stft_db[:, start + trim_begin : start + trim_end]) + segments['waveplot'].append(waveplot[:, start + trim_begin : start + trim_end]) + segments['costs'].append(costs[:, trim_begin:trim_end]) + if debug_path: + segments['canvas'].append(canvas[:, trim_begin:trim_end]) + + # Update metadata with segment start and stop + start_stop = { + 'segment start.ms': (start + trim_begin) * x_step_ms, + 'segment end.ms': (start + trim_end) * x_step_ms, + 'segment duration.ms': (trim_end - trim_begin) * x_step_ms, + } + metadata.update(start_stop) + + # Normalize values + for key, value in list(metadata.items()): + if value is None: + continue + if key.endswith('.ms'): + metadata[key] = round(float(value), 3) + if key.endswith('.hz'): + metadata[key] = int(round(value)) + if key.endswith('.flag'): + metadata[key] = bool(value) + if key.endswith('.y_px/x_px'): + key_ = key.replace('.y_px/x_px', '.khz/ms') + 
metadata[key_] = round(float(value * ((y_step_freq / 1000.0) / x_step_ms)), 3) + metadata.pop(key) + if key.endswith('.(hz,ms)'): + metadata[key] = [ + ( + int(round(val1)), + round(float(val2), 3), + ) + for val1, val2 in value + ] + + metas.append(metadata) # Concatenate extracted, trimmed segments and other matrices for key in list(segments.keys()): @@ -1617,9 +1823,12 @@ def compute_wrapper( output_paths = [] compressed_paths = [] - datas = [ - (output_paths, 'jpg', stft_db), - ] + if not fast_mode: + datas = [ + (output_paths, 'jpg', stft_db), + ] + else: + datas = [] if 'stft_db' in segments: datas += [ (compressed_paths, 'compressed.jpg', segments['stft_db']), @@ -1642,7 +1851,7 @@ def compute_wrapper( for index, chunk in enumerate(chunks): if chunk.shape[1] == 0: continue - output_path = join(output_folder, f'{base}.{index + 1:02d}of{total:02d}.{tag}') + output_path = f'{out_file_stem}.{index + 1:02d}of{total:02d}.{tag}' cv2.imwrite(output_path, chunk, [cv2.IMWRITE_JPEG_QUALITY, 80]) accumulator.append(output_path) @@ -1678,8 +1887,8 @@ def compute_wrapper( 'height.px': segments['stft_db'].shape[0], } - metadata_path = join(output_folder, f'{base}.metadata.json') + metadata_path = f'{out_file_stem}.metadata.json' with open(metadata_path, 'w') as metafile: json.dump(metadata, metafile, indent=4) - return output_paths, metadata_path, metadata + return output_paths, compressed_paths, metadata_path, metadata diff --git a/tests/test_spectrogram.py b/tests/test_spectrogram.py index 47fc1dc..5120082 100644 --- a/tests/test_spectrogram.py +++ b/tests/test_spectrogram.py @@ -6,4 +6,6 @@ def test_spectrogram_compute(): wav_filepath = abspath(join('examples', 'example2.wav')) output_folder = './output' - output_paths, metadata_path, metadata = compute(wav_filepath, output_folder=output_folder) + output_paths, compressed_paths, metadata_path, metadata = compute( + wav_filepath, output_folder=output_folder + )
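
For reference, a minimal sketch of consuming the updated four-value return signature of spectrogram.compute() from Python (paths hypothetical; assumes the package and example data are available):

.. code-block:: python

    from batbot import spectrogram

    # compute() now returns compressed spectrogram paths alongside the full ones
    output_paths, compressed_paths, metadata_path, metadata = spectrogram.compute(
        'examples/example2.wav',
        output_folder='./output',
    )
    print(len(output_paths), len(compressed_paths), metadata_path)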