Skip to content

Commit bce4552

Browse files
committed
First commit: demucs.cpp
0 parents  commit bce4552

32 files changed

+7001
-0
lines changed

.clang-format

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
UseTab: Never
2+
IndentWidth: 4
3+
BreakBeforeBraces: Allman
4+
AllowShortIfStatementsOnASingleLine: false
5+
IndentCaseLabels: false
6+
ColumnLimit: 80

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/build/
2+
*.bin
3+
*.wav
4+
!/test/data/*.wav
5+
*.json
6+
/__pycache__
7+
*.pyc

CMakeLists.txt

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# cmake file to compile src/
2+
# link against included submodules libnyquist
3+
4+
cmake_minimum_required(VERSION 3.0)
5+
6+
if(NOT CMAKE_BUILD_TYPE)
7+
set(CMAKE_BUILD_TYPE Release)
8+
endif()
9+
10+
set(CMAKE_CXX_FLAGS "-Wall -Wextra")
11+
set(CMAKE_CXX_FLAGS_DEBUG "-g -DEIGEN_FAST_MATH=0 -O0")
12+
13+
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -march=native -fno-unsafe-math-optimizations -fassociative-math -freciprocal-math -fno-signed-zeros")
14+
15+
# define a macro NDEBUG for Eigen3 release builds
16+
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
17+
18+
# set EIGEN_USE_BLAS to 1 and link to OpenBLAS
19+
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DEIGEN_USE_BLAS -DEIGEN_USE_LAPACKE")
20+
21+
project(demucs.cpp)
22+
enable_testing()
23+
24+
# set C++ standard to C++17
25+
set(CMAKE_CXX_STANDARD 17)
26+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
27+
28+
# add openmp support
29+
find_package(OpenMP REQUIRED)
30+
if(OPENMP_FOUND)
31+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
32+
include_directories(${OpenMP_CXX_INCLUDE_DIRS})
33+
endif()
34+
35+
# compile vendored submodule libnyquist
36+
set(LIBNYQUIST_BUILD_EXAMPLE OFF CACHE BOOL "Disable libnyquist example")
37+
add_subdirectory(vendor/libnyquist)
38+
39+
# add library Eigen3
40+
include_directories(vendor/eigen)
41+
42+
# add OpenBLAS for blas + lapack
43+
find_package(BLAS REQUIRED)
44+
find_package(LAPACK REQUIRED)
45+
46+
# include vendor submodules libnyquist
47+
include_directories(vendor/libnyquist/include)
48+
49+
# include src/ as include directory
50+
include_directories(src)
51+
52+
# include src/*.cpp and src/*.c as source files
53+
file(GLOB SOURCES "src/*.cpp")
54+
55+
# compile library, link against libnyquist
56+
add_library(demucs.cpp.lib SHARED ${SOURCES})
57+
target_link_libraries(demucs.cpp.lib libnyquist ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES} lapacke)
58+
if(OPENMP_FOUND)
59+
target_link_libraries(demucs.cpp.lib ${OpenMP_CXX_LIBRARIES})
60+
endif()
61+
62+
file(GLOB SOURCES_TO_LINT "src/*.cpp" "src/*.hpp" "demucs.cpp" "test/*.cpp")
63+
64+
# add target to run standard lints and formatters
65+
add_custom_target(lint
66+
COMMAND clang-format -i ${SOURCES_TO_LINT} --style=file
67+
# add clang-tidy command
68+
# add include dirs to clang-tidy
69+
COMMAND cppcheck --enable=all --suppress=missingIncludeSystem ${SOURCES_TO_LINT} --std=c++17
70+
COMMAND scan-build -o ${CMAKE_BINARY_DIR}/scan-build-report make -C ${CMAKE_BINARY_DIR}
71+
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
72+
)
73+
74+
# add target to compile demucs.cpp, the main driver program for demucs.cpp
75+
add_executable(demucs.cpp.main demucs.cpp)
76+
# link it against demucs.cpp.lib
77+
target_link_libraries(demucs.cpp.main demucs.cpp.lib)
78+
79+
# add target to run cpp tests in test/ directory with gtest
80+
81+
# include test/*.cpp as test files
82+
file(GLOB TEST_SOURCES "test/*.cpp")
83+
84+
add_executable(demucs.cpp.test ${TEST_SOURCES})
85+
target_link_libraries(demucs.cpp.test demucs.cpp.lib gtest gtest_main libnyquist)
86+
add_test(NAME tests COMMAND demucs.cpp.test)

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2023 Sevag H
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# demucs.cpp
2+
3+
Demucs v4 hybrid transformer model reimplemented in C++ with Eigen3
4+
5+
Track 'Zeno - Signs' from MUSDB18-HQ test set
6+
7+
PyTorch CLI inference (output of `demucs /path/to/track` from [this commit of demucs v4](https://github.com/facebookresearch/demucs@2496b8f7f12b01c8dd0187c040000c46e175b44d)):
8+
```
9+
vocals ==> SDR: 8.264 SIR: 18.353 ISR: 15.794 SAR: 8.303
10+
drums ==> SDR: 10.111 SIR: 18.503 ISR: 17.089 SAR: 10.746
11+
bass ==> SDR: 4.222 SIR: 12.615 ISR: 6.973 SAR: 2.974
12+
other ==> SDR: 7.397 SIR: 11.317 ISR: 14.303 SAR: 8.137
13+
```
14+
PyTorch custom inference in [my script](./scripts/demucs_pytorch_inference.py):
15+
```
16+
vocals ==> SDR: 8.339 SIR: 18.274 ISR: 15.835 SAR: 8.354
17+
drums ==> SDR: 10.058 SIR: 18.598 ISR: 17.023 SAR: 10.812
18+
bass ==> SDR: 3.926 SIR: 12.414 ISR: 6.941 SAR: 3.202
19+
other ==> SDR: 7.421 SIR: 11.289 ISR: 14.241 SAR: 8.179
20+
```
21+
CPP inference (this codebase):
22+
```
23+
vocals ==> SDR: 8.339 SIR: 18.276 ISR: 15.836 SAR: 8.346
24+
drums ==> SDR: 10.058 SIR: 18.596 ISR: 17.019 SAR: 10.810
25+
bass ==> SDR: 3.919 SIR: 12.436 ISR: 6.931 SAR: 3.182
26+
other ==> SDR: 7.421 SIR: 11.286 ISR: 14.252 SAR: 8.183
27+
```
28+
29+
*n.b.* for testing purposes in this repo, the random shift in the beginning of the song is fixed to 1337 in both PyTorch and C++.
30+
31+
## Build and run
32+
33+
Out-of-source build with CMake:
34+
```
35+
$ mkdir -p build && cd build && cmake -DCMAKE_BUILD_TYPE=Release ..
36+
$ make
37+
```
38+
39+
The `Release` build type adds optimization flags (Ofast etc.), without which this project is unusably slow.
40+
41+
Run:
42+
```
43+
$ ./demucs.cpp.main ../ggml-demucs/ggml-model-htdemucs-f16.bin ../test/data/gspi_stereo.wav ./demucs-out-cpp/
44+
```
45+
46+
## Hack
47+
48+
* make lint
49+
* Valgrind memory error test: `valgrind --leak-check=full --show-leak-kinds=all --track-origins=yes --verbose ./demucs.cpp.main ../ggml-demucs/ggml-model-htdemucs-f16.bin ../test/data/gspi_stereo.wav ./demucs-out-cpp/`
50+
*

demucs.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#include "dsp.hpp"
2+
#include "model.hpp"
3+
#include "tensor.hpp"
4+
#include <Eigen/Core>
5+
#include <Eigen/Dense>
6+
#include <cassert>
7+
#include <filesystem>
8+
#include <iostream>
9+
#include <sstream>
10+
#include <string>
11+
#include <thread>
12+
#include <unsupported/Eigen/FFT>
13+
#include <vector>
14+
15+
using namespace demucscpp;
16+
17+
int main(int argc, const char **argv)
18+
{
19+
if (argc != 4)
20+
{
21+
std::cerr << "Usage: " << argv[0]
22+
<< " <model file> <wav file> <out dir>" << std::endl;
23+
exit(1);
24+
}
25+
26+
// enable openmp parallelization for Eigen
27+
// init parallelism for eigen
28+
Eigen::initParallel();
29+
30+
// set eigen nb threads to physical cores minus 1
31+
// discover number of physical cores through C++ stdlib
32+
// https://stackoverflow.com/questions/150355/programmatically-find-the-number-of-cores-on-a-machine
33+
int nb_cores = std::thread::hardware_concurrency();
34+
std::cout << "Number of physical cores: " << nb_cores << std::endl;
35+
Eigen::setNbThreads(nb_cores - 1);
36+
37+
std::cout << "demucs.cpp Main driver program" << std::endl;
38+
39+
// load model passed as argument
40+
std::string model_file = argv[1];
41+
42+
// load audio passed as argument
43+
std::string wav_file = argv[2];
44+
45+
// output dir passed as argument
46+
std::string out_dir = argv[3];
47+
48+
Eigen::MatrixXf audio = load_audio(wav_file);
49+
Eigen::Tensor3dXf out_targets;
50+
51+
std::cout << "Using 4s model" << std::endl;
52+
53+
// initialize a struct demucs_model
54+
struct demucs_model_4s model
55+
{
56+
};
57+
58+
auto ret = load_demucs_model_4s(model_file, &model);
59+
std::cout << "demucs_model_load returned " << (ret ? "true" : "false")
60+
<< std::endl;
61+
if (!ret)
62+
{
63+
std::cerr << "Error loading model" << std::endl;
64+
exit(1);
65+
}
66+
67+
std::cout << "Starting demucs inference" << std::endl;
68+
69+
// create 4 audio matrix same size, to hold output
70+
Eigen::Tensor3dXf audio_targets =
71+
demucscpp::demucs_inference_4s(model, audio);
72+
std::cout << "returned!" << std::endl;
73+
74+
out_targets = audio_targets;
75+
76+
for (int target = 0; target < 4; ++target)
77+
{
78+
// now write the 4 audio waveforms to files in the output dir
79+
// using libnyquist
80+
// join out_dir with "/target_0.wav"
81+
// using std::filesystem::path;
82+
83+
std::filesystem::path p = out_dir;
84+
// make sure the directory exists
85+
std::filesystem::create_directories(p);
86+
87+
auto p_target = p / "target_0.wav";
88+
// generate p_target = p / "target_{target}.wav"
89+
p_target.replace_filename("target_" + std::to_string(target) + ".wav");
90+
91+
std::cout << "Writing wav file " << p_target << std::endl;
92+
93+
Eigen::MatrixXf target_waveform(2, audio.cols());
94+
95+
// copy the input stereo wav file into all 4 targets
96+
for (int channel = 0; channel < 2; ++channel)
97+
{
98+
for (int sample = 0; sample < audio.cols(); ++sample)
99+
{
100+
target_waveform(channel, sample) =
101+
out_targets(target, channel, sample);
102+
}
103+
}
104+
105+
demucscppdebug::debug_matrix_xf(target_waveform,
106+
"target_waveform for target " +
107+
std::to_string(target));
108+
109+
demucscpp::write_audio_file(target_waveform, p_target);
110+
}
111+
}

scripts/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)