// pybind11 extension module exposing the ca_*, psa_*, batchnorm / expectation ops.
#include "ca.h"
#include "psa.h"
#include "syncbn.h"
// Python bindings for the native extension ops.
//
// The module name is the TORCH_EXTENSION_NAME macro, which is not defined in
// this file (presumably injected by the PyTorch C++-extension build — confirm
// against the build script). Every op is registered under the same name as
// its C++ function, and its Python docstring is just that name.
//
// NOTE(review): the C++ declarations come from the headers included above
// (ca.h, psa.h, syncbn.h); the per-group header attribution in the comments
// below is inferred from the name prefixes — confirm. Consider replacing the
// name-only docstrings with real descriptions and adding py::arg(...) names
// once the signatures are confirmed.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // ca_* ops (presumably declared in ca.h).
  m.def("ca_forward", &ca_forward, "ca_forward");
  m.def("ca_backward", &ca_backward, "ca_backward");
  m.def("ca_map_forward", &ca_map_forward, "ca_map_forward");
  m.def("ca_map_backward", &ca_map_backward, "ca_map_backward");

  // psa_* ops (presumably declared in psa.h).
  m.def("psa_forward", &psa_forward, "psa_forward");
  m.def("psa_backward", &psa_backward, "psa_backward");

  // Batch-norm / expectation ops (presumably declared in syncbn.h).
  // The inp_* variants are registered alongside the plain ones; whether they
  // operate in place is not visible here — check the header.
  m.def("batchnorm_forward", &batchnorm_forward, "batchnorm_forward");
  m.def("inp_batchnorm_forward", &inp_batchnorm_forward, "inp_batchnorm_forward");
  m.def("batchnorm_backward", &batchnorm_backward, "batchnorm_backward");
  m.def("inp_batchnorm_backward", &inp_batchnorm_backward, "inp_batchnorm_backward");
  m.def("expectation_forward", &expectation_forward, "expectation_forward");
  m.def("expectation_backward", &expectation_backward, "expectation_backward");
  m.def("inp_expectation_backward", &inp_expectation_backward, "inp_expectation_backward");
}