cython wrapper for acados (#22784)

* cython wrapper for acados

* fix building

* sconscript cleanup

* no cython numpy

* cleanup

* upgrade build script

* try without slices

* new acados commit

* c3 update acados libs

* c2 libs

* make faster

* undo profiling

* fix build

* somewhat faster

* tryout cost_set_slice

* Revert "tryout cost_set_slice"

This reverts commit d358d93a133270e4edab9e7c07ffb6f577c52bd6.

* cleanup

* undo t_renderer change

Co-authored-by: Comma Device <device@comma.ai>
old-commit-hash: 89d0a52d16872c403c69426ab32e5788a41ee2ec
This commit is contained in:
Joost Wooning
2021-11-12 17:09:08 +01:00
committed by GitHub
parent 1bcba44589
commit 9042f18afd
60 changed files with 503 additions and 187 deletions

View File

@@ -68,6 +68,7 @@ lenv = {
"PYTHONPATH": Dir("#").abspath + ":" + Dir("#pyextra/").abspath,
"ACADOS_SOURCE_DIR": Dir("#third_party/acados/acados").abspath,
"ACADOS_PYTHON_INTERFACE_PATH": Dir("#pyextra/acados_template").abspath,
"TERA_PATH": Dir("#").abspath + f"/third_party/acados/{arch}/t_renderer",
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,4 +1,4 @@
Import('env', 'arch')
Import('env', 'envCython', 'arch', 'common')
gen = "c_generated_code"
@@ -33,6 +33,7 @@ generated_files = [
f'{gen}/main_lat.c',
f'{gen}/acados_solver_lat.h',
f'{gen}/acados_solver.pxd',
f'{gen}/lat_model/lat_expl_vde_adj.c',
@@ -53,6 +54,24 @@ lenv["CFLAGS"].append("-DACADOS_WITH_QPOASES")
lenv["CXXFLAGS"].append("-DACADOS_WITH_QPOASES")
lenv["CCFLAGS"].append("-Wno-unused")
lenv["LINKFLAGS"].append("-Wl,--disable-new-dtags")
lenv.SharedLibrary(f"{gen}/acados_ocp_solver_lat",
build_files,
LIBS=['m', 'acados', 'hpipm', 'blasfeo', 'qpOASES_e'])
lib_solver = lenv.SharedLibrary(f"{gen}/acados_ocp_solver_lat",
build_files,
LIBS=['m', 'acados', 'hpipm', 'blasfeo', 'qpOASES_e'])
# generate cython stuff
acados_ocp_solver_pyx = File("#pyextra/acados_template/acados_ocp_solver_pyx.pyx")
acados_ocp_solver_common = File("#pyextra/acados_template/acados_solver_common.pxd")
libacados_ocp_solver_pxd = File(f'{gen}/acados_solver.pxd')
libacados_ocp_solver_c = File(f'{gen}/acados_ocp_solver_pyx.c')
lenv2 = envCython.Clone()
lenv2["LINKFLAGS"] += [lib_solver[0].get_labspath()]
lenv2.Command(libacados_ocp_solver_c,
[acados_ocp_solver_pyx, acados_ocp_solver_common, libacados_ocp_solver_pxd],
f'cython' + \
f' -o {libacados_ocp_solver_c.get_labspath()}' + \
f' -I {libacados_ocp_solver_pxd.get_dir().get_labspath()}' + \
f' -I {acados_ocp_solver_common.get_dir().get_labspath()}' + \
f' {acados_ocp_solver_pyx.get_labspath()}')
lib_cython = lenv2.Program(f'{gen}/acados_ocp_solver_pyx.so', [libacados_ocp_solver_c])
lenv2.Depends(lib_cython, lib_solver)

View File

@@ -5,8 +5,12 @@ import numpy as np
from casadi import SX, vertcat, sin, cos
from selfdrive.controls.lib.drive_helpers import LAT_MPC_N as N
from selfdrive.controls.lib.drive_helpers import T_IDXS
from pyextra.acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
if __name__ == '__main__': # generating code
from pyextra.acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
else:
# from pyextra.acados_template import AcadosOcpSolverFast
from selfdrive.controls.lib.lateral_mpc_lib.c_generated_code.acados_ocp_solver_pyx import AcadosOcpSolverFast # pylint: disable=no-name-in-module, import-error
LAT_MPC_DIR = os.path.dirname(os.path.abspath(__file__))
EXPORT_DIR = os.path.join(LAT_MPC_DIR, "c_generated_code")
@@ -110,17 +114,16 @@ def gen_lat_mpc_solver():
class LateralMpc():
def __init__(self, x0=np.zeros(X_DIM)):
self.solver = AcadosOcpSolver('lat', N, EXPORT_DIR)
self.solver = AcadosOcpSolverFast('lat', N, EXPORT_DIR)
self.reset(x0)
def reset(self, x0=np.zeros(X_DIM)):
self.x_sol = np.zeros((N+1, X_DIM))
self.u_sol = np.zeros((N, 1))
self.yref = np.zeros((N+1, 3))
self.solver.cost_set_slice(0, N, "yref", self.yref[:N])
for i in range(N):
self.solver.cost_set(i, "yref", self.yref[i])
self.solver.cost_set(N, "yref", self.yref[N][:2])
W = np.eye(3)
self.Ws = np.tile(W[None], reps=(N,1,1))
# Somehow needed for stable init
for i in range(N+1):
@@ -132,12 +135,11 @@ class LateralMpc():
self.cost = 0
def set_weights(self, path_weight, heading_weight, steer_rate_weight):
self.Ws[:,0,0] = path_weight
self.Ws[:,1,1] = heading_weight
self.Ws[:,2,2] = steer_rate_weight
self.solver.cost_set_slice(0, N, 'W', self.Ws, api='old')
W = np.asfortranarray(np.diag([path_weight, heading_weight, steer_rate_weight]))
for i in range(N):
self.solver.cost_set(i, 'W', W)
#TODO hacky weights to keep behavior the same
self.solver.cost_set(N, 'W', (3/20.)*self.Ws[0,:2,:2])
self.solver.cost_set(N, 'W', (3/20.)*W[:2,:2])
def run(self, x0, v_ego, car_rotation_radius, y_pts, heading_pts):
x0_cp = np.copy(x0)
@@ -145,12 +147,15 @@ class LateralMpc():
self.solver.constraints_set(0, "ubx", x0_cp)
self.yref[:,0] = y_pts
self.yref[:,1] = heading_pts*(v_ego+5.0)
self.solver.cost_set_slice(0, N, "yref", self.yref[:N])
for i in range(N):
self.solver.cost_set(i, "yref", self.yref[i])
self.solver.cost_set(N, "yref", self.yref[N][:2])
self.solution_status = self.solver.solve()
self.solver.fill_in_slice(0, N+1, 'x', self.x_sol)
self.solver.fill_in_slice(0, N, 'u', self.u_sol)
for i in range(N+1):
self.x_sol[i] = self.solver.get(i, 'x')
for i in range(N):
self.u_sol[i] = self.solver.get(i, 'u')
self.cost = self.solver.get_cost()

View File

@@ -1,4 +1,4 @@
Import('env', 'arch')
Import('env', 'envCython', 'arch', 'common')
gen = "c_generated_code"
@@ -41,6 +41,7 @@ generated_files = [
f'{gen}/main_long.c',
f'{gen}/acados_solver_long.h',
f'{gen}/acados_solver.pxd',
f'{gen}/long_model/long_expl_vde_adj.c',
@@ -63,6 +64,24 @@ lenv["CFLAGS"].append("-DACADOS_WITH_QPOASES")
lenv["CXXFLAGS"].append("-DACADOS_WITH_QPOASES")
lenv["CCFLAGS"].append("-Wno-unused")
lenv["LINKFLAGS"].append("-Wl,--disable-new-dtags")
lenv.SharedLibrary(f"{gen}/acados_ocp_solver_long",
build_files,
LIBS=['m', 'acados', 'hpipm', 'blasfeo', 'qpOASES_e'])
lib_solver = lenv.SharedLibrary(f"{gen}/acados_ocp_solver_long",
build_files,
LIBS=['m', 'acados', 'hpipm', 'blasfeo', 'qpOASES_e'])
# generate cython stuff
acados_ocp_solver_pyx = File("#pyextra/acados_template/acados_ocp_solver_pyx.pyx")
acados_ocp_solver_common = File("#pyextra/acados_template/acados_solver_common.pxd")
libacados_ocp_solver_pxd = File(f'{gen}/acados_solver.pxd')
libacados_ocp_solver_c = File(f'{gen}/acados_ocp_solver_pyx.c')
lenv2 = envCython.Clone()
lenv2["LINKFLAGS"] += [lib_solver[0].get_labspath()]
lenv2.Command(libacados_ocp_solver_c,
[acados_ocp_solver_pyx, acados_ocp_solver_common, libacados_ocp_solver_pxd],
f'cython' + \
f' -o {libacados_ocp_solver_c.get_labspath()}' + \
f' -I {libacados_ocp_solver_pxd.get_dir().get_labspath()}' + \
f' -I {acados_ocp_solver_common.get_dir().get_labspath()}' + \
f' {acados_ocp_solver_pyx.get_labspath()}')
lib_cython = lenv2.Program(f'{gen}/acados_ocp_solver_pyx.so', [libacados_ocp_solver_c])
lenv2.Depends(lib_cython, lib_solver)

View File

@@ -8,7 +8,12 @@ from selfdrive.swaglog import cloudlog
from selfdrive.modeld.constants import index_function
from selfdrive.controls.lib.radar_helpers import _LEAD_ACCEL_TAU
from pyextra.acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
if __name__ == '__main__': # generating code
from pyextra.acados_template import AcadosModel, AcadosOcp, AcadosOcpSolver
else:
# from pyextra.acados_template import AcadosOcpSolver as AcadosOcpSolverFast
from selfdrive.controls.lib.longitudinal_mpc_lib.c_generated_code.acados_ocp_solver_pyx import AcadosOcpSolverFast # pylint: disable=no-name-in-module, import-error
from casadi import SX, vertcat
LONG_MPC_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -190,13 +195,14 @@ class LongitudinalMpc():
self.source = SOURCES[2]
def reset(self):
self.solver = AcadosOcpSolver('long', N, EXPORT_DIR)
self.solver = AcadosOcpSolverFast('long', N, EXPORT_DIR)
self.v_solution = [0.0 for i in range(N+1)]
self.a_solution = [0.0 for i in range(N+1)]
self.j_solution = [0.0 for i in range(N)]
self.yref = np.zeros((N+1, COST_DIM))
self.solver.cost_set_slice(0, N, "yref", self.yref[:N])
self.solver.set(N, "yref", self.yref[N][:COST_E_DIM])
for i in range(N):
self.solver.cost_set(i, "yref", self.yref[i])
self.solver.cost_set(N, "yref", self.yref[N][:COST_E_DIM])
self.x_sol = np.zeros((N+1, X_DIM))
self.u_sol = np.zeros((N,1))
self.params = np.zeros((N+1,3))
@@ -216,30 +222,30 @@ class LongitudinalMpc():
self.set_weights_for_lead_policy()
def set_weights_for_lead_policy(self):
W = np.diag([X_EGO_OBSTACLE_COST, X_EGO_COST, V_EGO_COST, A_EGO_COST, J_EGO_COST])
Ws = np.tile(W[None], reps=(N,1,1))
self.solver.cost_set_slice(0, N, 'W', Ws, api='old')
W = np.asfortranarray(np.diag([X_EGO_OBSTACLE_COST, X_EGO_COST, V_EGO_COST, A_EGO_COST, J_EGO_COST]))
for i in range(N):
self.solver.cost_set(i, 'W', W)
# Setting the slice without the copy make the array not contiguous,
# causing issues with the C interface.
self.solver.cost_set(N, 'W', np.copy(W[:COST_E_DIM, :COST_E_DIM]))
# Set L2 slack cost on lower bound constraints
Zl = np.array([LIMIT_COST, LIMIT_COST, LIMIT_COST, DANGER_ZONE_COST])
Zls = np.tile(Zl[None], reps=(N+1,1,1))
self.solver.cost_set_slice(0, N+1, 'Zl', Zls, api='old')
for i in range(N):
self.solver.cost_set(i, 'Zl', Zl)
def set_weights_for_xva_policy(self):
W = np.diag([0., 10., 1., 10., 1.])
Ws = np.tile(W[None], reps=(N,1,1))
self.solver.cost_set_slice(0, N, 'W', Ws, api='old')
W = np.asfortranarray(np.diag([0., 10., 1., 10., 1.]))
for i in range(N):
self.solver.cost_set(i, 'W', W)
# Setting the slice without the copy make the array not contiguous,
# causing issues with the C interface.
self.solver.cost_set(N, 'W', np.copy(W[:COST_E_DIM, :COST_E_DIM]))
# Set L2 slack cost on lower bound constraints
Zl = np.array([LIMIT_COST, LIMIT_COST, LIMIT_COST, 0.0])
Zls = np.tile(Zl[None], reps=(N+1,1,1))
self.solver.cost_set_slice(0, N+1, 'Zl', Zls, api='old')
for i in range(N):
self.solver.cost_set(i, 'Zl', Zl)
def set_cur_state(self, v, a):
if abs(self.x0[1] - v) > 1.:
@@ -326,8 +332,9 @@ class LongitudinalMpc():
self.yref[:,1] = x
self.yref[:,2] = v
self.yref[:,3] = a
self.solver.cost_set_slice(0, N, "yref", self.yref[:N], api='old')
self.solver.set(N, "yref", self.yref[N][:COST_E_DIM])
for i in range(N):
self.solver.cost_set(i, "yref", self.yref[i])
self.solver.cost_set(N, "yref", self.yref[N][:COST_E_DIM])
self.accel_limit_arr[:,0] = -10.
self.accel_limit_arr[:,1] = 10.
x_obstacle = 1e5*np.ones((N+1))
@@ -338,12 +345,14 @@ class LongitudinalMpc():
def run(self):
for i in range(N+1):
self.solver.set_param(i, self.params[i])
self.solver.set(i, 'p', self.params[i])
self.solver.constraints_set(0, "lbx", self.x0)
self.solver.constraints_set(0, "ubx", self.x0)
self.solution_status = self.solver.solve()
self.solver.fill_in_slice(0, N+1, 'x', self.x_sol)
self.solver.fill_in_slice(0, N, 'u', self.u_sol)
for i in range(N+1):
self.x_sol[i] = self.solver.get(i, 'x')
for i in range(N):
self.u_sol[i] = self.solver.get(i, 'u')
self.v_solution = self.x_sol[:,1]
self.a_solution = self.x_sol[:,2]

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -18,7 +18,7 @@ if [ ! -d acados_repo/ ]; then
fi
cd acados_repo
git fetch
git checkout 43ba28e95062f9ac9b48facd3b45698d57666fa3
git checkout 79e9e3e76f2751198858adf382c97837833ad31f
git submodule update --recursive --init
# build

View File

@@ -135,6 +135,8 @@ typedef struct ocp_nlp_dims
int *nz; // number of algebraic variables
int *ns; // number of slack variables
int N; // number of shooting nodes
void *raw_memory; // Pointer to allocated memory, to be used for freeing
} ocp_nlp_dims;
//
@@ -203,6 +205,9 @@ typedef struct ocp_nlp_in
/// Pointers to constraints functions (TBC).
void **constraints;
/// Pointer to allocated memory, to be used for freeing.
void *raw_memory;
} ocp_nlp_in;
//
@@ -235,6 +240,8 @@ typedef struct ocp_nlp_out
double inf_norm_res;
double total_time;
void *raw_memory; // Pointer to allocated memory, to be used for freeing
} ocp_nlp_out;
//

View File

@@ -43,44 +43,53 @@ extern "C" {
enum Newton_type_collocation
// enum Newton_type_collocation
// {
// exact = 0,
// simplified_in,
// simplified_inis
// };
// typedef struct
// {
// enum Newton_type_collocation type;
// double *eig;
// double *low_tria;
// bool single;
// bool freeze;
// double *transf1;
// double *transf2;
// double *transf1_T;
// double *transf2_T;
// } Newton_scheme;
typedef enum
{
exact = 0,
simplified_in,
simplified_inis
};
typedef struct
{
enum Newton_type_collocation type;
double *eig;
double *low_tria;
bool single;
bool freeze;
double *transf1;
double *transf2;
double *transf1_T;
double *transf2_T;
} Newton_scheme;
GAUSS_LEGENDRE,
GAUSS_RADAU_IIA,
} sim_collocation_type;
//
acados_size_t gauss_nodes_work_calculate_size(int ns);
// acados_size_t gauss_legendre_nodes_work_calculate_size(int ns);
//
void gauss_nodes(int ns, double *nodes, void *raw_memory);
// void gauss_legendre_nodes(int ns, double *nodes, void *raw_memory);
//
acados_size_t gauss_simplified_work_calculate_size(int ns);
// acados_size_t gauss_simplified_work_calculate_size(int ns);
// //
// void gauss_simplified(int ns, Newton_scheme *scheme, void *work);
//
void gauss_simplified(int ns, Newton_scheme *scheme, void *work);
acados_size_t butcher_tableau_work_calculate_size(int ns);
//
acados_size_t butcher_table_work_calculate_size(int ns);
// void calculate_butcher_tableau_from_nodes(int ns, double *nodes, double *b, double *A, void *work);
//
void butcher_table(int ns, double *nodes, double *b, double *A, void *work);
void calculate_butcher_tableau(int ns, sim_collocation_type collocation_type, double *c_vec,
double *b_vec, double *A_mat, void *work);

View File

@@ -141,12 +141,13 @@ typedef struct
bool output_z; // 1 -- if zn should be computed
bool sens_algebraic; // 1 -- if S_algebraic should be computed
bool exact_z_output; // 1 -- if z, S_algebraic should be computed exactly, extra Newton iterations
sim_collocation_type collocation_type;
// for explicit integrators: newton_iter == 0 && scheme == NULL
// && jac_reuse=false
int newton_iter;
bool jac_reuse;
Newton_scheme *scheme;
// Newton_scheme *scheme;
// workspace
void *work;

View File

@@ -40,7 +40,7 @@ extern "C" {
#include "acados/utils/types.h"
#if defined(__DSPACE__)
#if defined(__MABX2__)
double fmax(double a, double b);
#endif

View File

@@ -67,7 +67,7 @@ typedef struct acados_timer_
mach_timebase_info_data_t tinfo;
} acados_timer;
#elif defined(__DSPACE__)
#elif defined(__MABX2__)
#include <brtenv.h>

View File

@@ -227,10 +227,8 @@ int ocp_nlp_dynamics_model_set(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_n
/// \param field The name of the field, either nls_res_jac,
/// y_ref, W (others TBC)
/// \param value Cost values.
int ocp_nlp_cost_model_set_slice(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_in *in,
int start_stage, int end_stage, const char *field, void *value, int dim);
int ocp_nlp_cost_model_set(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_in *in,
int start_stage, const char *field, void *value);
int stage, const char *field, void *value);
/// Sets the function pointers to the constraints functions for the given stage.
@@ -241,8 +239,6 @@ int ocp_nlp_cost_model_set(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_i
/// \param stage Stage number.
/// \param field The name of the field, either lb, ub (others TBC)
/// \param value Constraints function or values.
int ocp_nlp_constraints_model_set_slice(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_in *in,
int start_stage, int end_stage, const char *field, void *value, int dim);
int ocp_nlp_constraints_model_set(ocp_nlp_config *config, ocp_nlp_dims *dims,
ocp_nlp_in *in, int stage, const char *field, void *value);
@@ -283,9 +279,6 @@ void ocp_nlp_out_set(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *ou
void ocp_nlp_out_get(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *out,
int stage, const char *field, void *value);
void ocp_nlp_out_get_slice(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *out,
int start_stage, int end_stage, const char *field, void *value);
//
void ocp_nlp_get_at_stage(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_solver *solver,
int stage, const char *field, void *value);
@@ -300,6 +293,8 @@ void ocp_nlp_constraint_dims_get_from_attr(ocp_nlp_config *config, ocp_nlp_dims
void ocp_nlp_cost_dims_get_from_attr(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *out,
int stage, const char *field, int *dims_out);
void ocp_nlp_dynamics_dims_get_from_attr(ocp_nlp_config *config, ocp_nlp_dims *dims, ocp_nlp_out *out,
int stage, const char *field, int *dims_out);
/* opts */
@@ -374,6 +369,8 @@ int ocp_nlp_precompute(ocp_nlp_solver *solver, ocp_nlp_in *nlp_in, ocp_nlp_out *
/// \param nlp_out The output struct.
void ocp_nlp_eval_cost(ocp_nlp_solver *solver, ocp_nlp_in *nlp_in, ocp_nlp_out *nlp_out);
//
void ocp_nlp_eval_residuals(ocp_nlp_solver *solver, ocp_nlp_in *nlp_in, ocp_nlp_out *nlp_out);
//
void ocp_nlp_eval_param_sens(ocp_nlp_solver *solver, char *field, int stage, int index, ocp_nlp_out *sens_nlp_out);

View File

@@ -45,11 +45,11 @@ extern "C" {
typedef enum
{
ERK,
IRK,
GNSF,
LIFTED_IRK,
INVALID_SIM_SOLVER,
ERK,
IRK,
GNSF,
LIFTED_IRK,
INVALID_SIM_SOLVER,
} sim_solver_t;

View File

@@ -43,7 +43,28 @@
#if defined( TARGET_X64_INTEL_HASWELL )
#if defined( TARGET_X64_INTEL_SKYLAKE_X )
// common
#define CACHE_LINE_SIZE 64 // data cache size: 64 bytes
#define L1_CACHE_SIZE (32*1024) // L1 data cache size: 32 kB, 8-way
#define L2_CACHE_SIZE (256*1024) //(1024*1024) // L2 data cache size: 1 MB ; DTLB1 64*4 kB = 256 kB
#define LLC_CACHE_SIZE (6*1024*1024) //(8*1024*1024) // LLC cache size: 8 MB ; TLB 1536*4 kB = 6 MB
// double
#define D_PS 8 // panel size
#define D_PLD 8 // 4 // GCD of panel length
#define D_M_KERNEL 24 // max kernel size
#define D_KC 128 //256 // 192
#define D_NC 144 //72 //96 //72 // 120 // 512
#define D_MC 2400 // 6000
// single
#define S_PS 16 // panel size
#define S_PLD 4 // GCD of panel length TODO probably 16 when writing assebly
#define S_M_KERNEL 32 // max kernel size
#define S_KC 128 //256
#define S_NC 128 //144
#define S_MC 3000
#elif defined( TARGET_X64_INTEL_HASWELL )
// common
#define CACHE_LINE_SIZE 64 // data cache size: 64 bytes
#define L1_CACHE_SIZE (32*1024) // L1 data cache size: 32 kB, 8-way

View File

@@ -101,18 +101,18 @@ void blasfeo_dtrsv_ltu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct bl
void blasfeo_dtrsv_unn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= inv( A^T ) * x, A (m)x(m) upper, transposed, not_unit
void blasfeo_dtrsv_utn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A * x ; A lower triangular
void blasfeo_dtrmv_lnn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A * x ; A lower triangular, unit diagonal
void blasfeo_dtrmv_lnu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A^T * x ; A lower triangular
void blasfeo_dtrmv_ltn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A^T * x ; A lower triangular, unit diagonal
void blasfeo_dtrmv_ltu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= beta * y + alpha * A * x ; A upper triangular
void blasfeo_dtrmv_unn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A^T * x ; A upper triangular
void blasfeo_dtrmv_utn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A * x ; A lower triangular
void blasfeo_dtrmv_lnn(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A^T * x ; A lower triangular
void blasfeo_dtrmv_ltn(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A * x ; A lower triangular, unit diagonal
void blasfeo_dtrmv_lnu(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A^T * x ; A lower triangular, unit diagonal
void blasfeo_dtrmv_ltu(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z_n <= beta_n * y_n + alpha_n * A * x_n
// z_t <= beta_t * y_t + alpha_t * A^T * x_t
void blasfeo_dgemv_nt(int m, int n, double alpha_n, double alpha_t, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx_n, int xi_n, struct blasfeo_dvec *sx_t, int xi_t, double beta_n, double beta_t, struct blasfeo_dvec *sy_n, int yi_n, struct blasfeo_dvec *sy_t, int yi_t, struct blasfeo_dvec *sz_n, int zi_n, struct blasfeo_dvec *sz_t, int zi_t);

View File

@@ -101,23 +101,24 @@ void blasfeo_ref_dtrsv_ltu(int m, struct blasfeo_dmat *sA, int ai, int aj, struc
void blasfeo_ref_dtrsv_unn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= inv( A' ) * x, A (m)x(m) upper, transposed, not_unit
void blasfeo_ref_dtrsv_utn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A * x ; A lower triangular
void blasfeo_ref_dtrmv_lnn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A * x ; A lower triangular, unit diagonal
void blasfeo_ref_dtrmv_lnu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A' * x ; A lower triangular
void blasfeo_ref_dtrmv_ltn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A' * x ; A lower triangular, unit diagonal
void blasfeo_ref_dtrmv_ltu(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= beta * y + alpha * A * x ; A upper triangular
void blasfeo_ref_dtrmv_unn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A' * x ; A upper triangular
void blasfeo_ref_dtrmv_utn(int m, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A * x ; A lower triangular
void blasfeo_ref_dtrmv_lnn(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A' * x ; A lower triangular
void blasfeo_ref_dtrmv_ltn(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A * x ; A lower triangular, unit diagonal
void blasfeo_ref_dtrmv_lnu(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z <= A' * x ; A lower triangular, unit diagonal
void blasfeo_ref_dtrmv_ltu(int m, int n, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, struct blasfeo_dvec *sz, int zi);
// z_n <= beta_n * y_n + alpha_n * A * x_n
// z_t <= beta_t * y_t + alpha_t * A' * x_t
void blasfeo_ref_dgemv_nt(int m, int n, double alpha_n, double alpha_t, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx_n, int xi_n, struct blasfeo_dvec *sx_t, int xi_t, double beta_n, double beta_t, struct blasfeo_dvec *sy_n, int yi_n, struct blasfeo_dvec *sy_t, int yi_t, struct blasfeo_dvec *sz_n, int zi_n, struct blasfeo_dvec *sz_t, int zi_t);
// z <= beta * y + alpha * A * x, where A is symmetric and only the lower triangular patr of A is accessed
void blasfeo_ref_dsymv_l(int m, int n, double alpha, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, double beta, struct blasfeo_dvec *sy, int yi, struct blasfeo_dvec *sz, int zi);
void blasfeo_ref_dsymv_l(int m, double alpha, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, double beta, struct blasfeo_dvec *sy, int yi, struct blasfeo_dvec *sz, int zi);
void blasfeo_ref_dsymv_l_mn(int m, int n, double alpha, struct blasfeo_dmat *sA, int ai, int aj, struct blasfeo_dvec *sx, int xi, double beta, struct blasfeo_dvec *sy, int yi, struct blasfeo_dvec *sz, int zi);
// diagonal

View File

@@ -54,6 +54,158 @@ void blasfeo_align_4096_byte(void *ptr, void **ptr_align);
void blasfeo_align_64_byte(void *ptr, void **ptr_align);
//
// lib8
//
// 24x8
void kernel_dgemm_nt_24x8_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd); //
void kernel_dgemm_nt_24x8_vs_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, int m1, int n1); //
void kernel_dtrsm_nt_rl_inv_24x8_lib8(int k, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E); //
void kernel_dpotrf_nt_l_24x8_lib8(int k, double *A, int sda, double *B, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
void kernel_dpotrf_nt_l_24x8_vs_lib8(int k, double *A, int sda, double *B, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int m1, int n1);
void kernel_dtrsm_nt_rl_inv_24x8_vs_lib8(int k, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int m1, int n1); //
void kernel_dgemm_dtrsm_nt_rl_inv_24x8_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E);
void kernel_dgemm_dtrsm_nt_rl_inv_24x8_vs_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int m1, int n1);
void kernel_dsyrk_dpotrf_nt_l_24x8_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
void kernel_dsyrk_dpotrf_nt_l_24x8_vs_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int m1, int n1);
void kernel_dlarfb8_rn_24_lib8(int kmax, double *pV, double *pT, double *pD, int sdd);
// 16x8
void kernel_dgemm_nt_16x8_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd); //
void kernel_dgemm_nt_16x8_vs_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, int m1, int n1); //
void kernel_dgemm_nt_16x8_gen_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
void kernel_dgemm_nn_16x8_lib8(int k, double *alpha, double *A, int sda, int offB, double *B, int sdb, double *beta, double *C, int sdc, double *D, int sdd); //
void kernel_dgemm_nn_16x8_vs_lib8(int k, double *alpha, double *A, int sda, int offB, double *B, int sdb, double *beta, double *C, int sdc, double *D, int sdd, int m1, int n1); //
void kernel_dgemm_nn_16x8_gen_lib8(int k, double *alpha, double *A, int sda, int offB, double *B, int sdb, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
void kernel_dsyrk_nt_l_16x8_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd); //
void kernel_dsyrk_nt_l_16x8_vs_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, int m1, int n1); //
void kernel_dsyrk_nt_l_16x8_gen_lib8(int k, double *alpha, double *A, int sda, double *B, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
void kernel_dtrmm_nn_rl_16x8_lib8(int k, double *alpha, double *A, int sda, int offsetB, double *B, int sdb, double *D, int sdd);
void kernel_dtrmm_nn_rl_16x8_vs_lib8(int k, double *alpha, double *A, int sda, int offsetB, double *B, int sdb, double *D, int sdd, int m1, int n1);
void kernel_dtrmm_nn_rl_16x8_gen_lib8(int k, double *alpha, double *A, int sda, int offsetB, double *B, int sdb, int offD, double *D, int sdd, int m0, int m1, int n0, int n1);
void kernel_dtrsm_nt_rl_inv_16x8_lib8(int k, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E); //
void kernel_dtrsm_nt_rl_inv_16x8_vs_lib8(int k, double *A, int sda, double *B, double *beta, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int m1, int n1); //
void kernel_dpotrf_nt_l_16x8_lib8(int k, double *A, int sda, double *B, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
void kernel_dpotrf_nt_l_16x8_vs_lib8(int k, double *A, int sda, double *B, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int m1, int n1);
void kernel_dgemm_dtrsm_nt_rl_inv_16x8_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E);
void kernel_dgemm_dtrsm_nt_rl_inv_16x8_vs_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int m1, int n1);
void kernel_dsyrk_dpotrf_nt_l_16x8_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
void kernel_dsyrk_dpotrf_nt_l_16x8_vs_lib8(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int m1, int n1);
void kernel_dlarfb8_rn_16_lib8(int kmax, double *pV, double *pT, double *pD, int sdd);
void kernel_dlarfb8_rn_la_16_lib8(int n1, double *pVA, double *pT, double *pD, int sdd, double *pA, int sda);
void kernel_dlarfb8_rn_lla_16_lib8(int n0, int n1, double *pVL, double *pVA, double *pT, double *pD, int sdd, double *pL, int sdl, double *pA, int sda);
// 8x16
void kernel_dgemm_tt_8x16_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, int sdb, double *beta, double *C, double *D); //
void kernel_dgemm_tt_8x16_vs_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, int sdb, double *beta, double *C, double *D, int m1, int n1); //
void kernel_dgemm_tt_8x16_gen_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, int sdb, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
void kernel_dgemm_nt_8x16_lib8(int k, double *alpha, double *A, double *B, int sdb, double *beta, double *C, double *D); //
void kernel_dgemm_nt_8x16_vs_lib8(int k, double *alpha, double *A, double *B, int sdb, double *beta, double *C, double *D, int m1, int n1); //
// 8x8
void kernel_dgemm_nt_8x8_lib8(int k, double *alpha, double *A, double *B, double *beta, double *C, double *D); //
void kernel_dgemm_nt_8x8_vs_lib8(int k, double *alpha, double *A, double *B, double *beta, double *C, double *D, int m1, int n1); //
void kernel_dgemm_nt_8x8_gen_lib8(int k, double *alpha, double *A, double *B, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
void kernel_dgemm_nn_8x8_lib8(int k, double *alpha, double *A, int offB, double *B, int sdb, double *beta, double *C, double *D); //
void kernel_dgemm_nn_8x8_vs_lib8(int k, double *alpha, double *A, int offB, double *B, int sdb, double *beta, double *C, double *D, int m1, int n1); //
void kernel_dgemm_nn_8x8_gen_lib8(int k, double *alpha, double *A, int offB, double *B, int sdb, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
void kernel_dgemm_tt_8x8_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, double *beta, double *C, double *D); //
void kernel_dgemm_tt_8x8_vs_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, double *beta, double *C, double *D, int m1, int n1); //
void kernel_dgemm_tt_8x8_gen_lib8(int k, double *alpha, int offA, double *A, int sda, double *B, double *beta, int offc, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
void kernel_dsyrk_nt_l_8x8_lib8(int k, double *alpha, double *A, double *B, double *beta, double *C, double *D); //
void kernel_dsyrk_nt_l_8x8_vs_lib8(int k, double *alpha, double *A, double *B, double *beta, double *C, double *D, int m1, int n1); //
void kernel_dsyrk_nt_l_8x8_gen_lib8(int k, double *alpha, double *A, double *B, double *beta, int offC, double *C, int sdc, int offD, double *D, int sdd, int m0, int m1, int n0, int n1); //
void kernel_dtrmm_nn_rl_8x8_lib8(int k, double *alpha, double *A, int offsetB, double *B, int sdb, double *D);
void kernel_dtrmm_nn_rl_8x8_vs_lib8(int k, double *alpha, double *A, int offsetB, double *B, int sdb, double *D, int m1, int n1);
void kernel_dtrmm_nn_rl_8x8_gen_lib8(int k, double *alpha, double *A, int offsetB, double *B, int sdb, int offD, double *D, int sdd, int m0, int m1, int n0, int n1);
void kernel_dtrsm_nt_rl_inv_8x8_lib8(int k, double *A, double *B, double *beta, double *C, double *D, double *E, double *inv_diag_E);
void kernel_dtrsm_nt_rl_inv_8x8_vs_lib8(int k, double *A, double *B, double *beta, double *C, double *D, double *E, double *inv_diag_E, int m1, int n1);
void kernel_dpotrf_nt_l_8x8_lib8(int k, double *A, double *B, double *C, double *D, double *inv_diag_D);
void kernel_dpotrf_nt_l_8x8_vs_lib8(int k, double *A, double *B, double *C, double *D, double *inv_diag_D, int m1, int n1);
void kernel_dgemm_dtrsm_nt_rl_inv_8x8_lib8(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E);
void kernel_dgemm_dtrsm_nt_rl_inv_8x8_vs_lib8(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E, int m1, int n1);
void kernel_dsyrk_dpotrf_nt_l_8x8_lib8(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D);
void kernel_dsyrk_dpotrf_nt_l_8x8_vs_lib8(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int m1, int n1);
void kernel_dgelqf_vs_lib8(int m, int n, int k, int offD, double *pD, int sdd, double *dD);
void kernel_dgelqf_pd_vs_lib8(int m, int n, int k, int offD, double *pD, int sdd, double *dD);
void kernel_dgelqf_8_lib8(int kmax, double *pD, double *dD);
void kernel_dgelqf_pd_8_lib8(int kmax, double *pD, double *dD);
void kernel_dlarft_8_lib8(int kmax, double *pD, double *dD, double *pT);
void kernel_dlarfb8_rn_8_lib8(int kmax, double *pV, double *pT, double *pD);
void kernel_dlarfb8_rn_8_vs_lib8(int kmax, double *pV, double *pT, double *pD, int m1);
void kernel_dlarfb8_rn_1_lib8(int kmax, double *pV, double *pT, double *pD);
void kernel_dgelqf_dlarft8_8_lib8(int kmax, double *pD, double *dD, double *pT);
void kernel_dgelqf_pd_dlarft8_8_lib8(int kmax, double *pD, double *dD, double *pT);
void kernel_dgelqf_pd_la_vs_lib8(int m, int n1, int k, int offD, double *pD, int sdd, double *dD, int offA, double *pA, int sda);
void kernel_dgelqf_pd_la_dlarft8_8_lib8(int kmax, double *pD, double *dD, double *pA, double *pT);
void kernel_dlarft_la_8_lib8(int n1, double *dD, double *pA, double *pT);
void kernel_dlarfb8_rn_la_8_lib8(int n1, double *pVA, double *pT, double *pD, double *pA);
void kernel_dlarfb8_rn_la_8_vs_lib8(int n1, double *pVA, double *pT, double *pD, double *pA, int m1);
void kernel_dlarfb8_rn_la_1_lib8(int n1, double *pVA, double *pT, double *pD, double *pA);
void kernel_dgelqf_pd_lla_vs_lib8(int m, int n0, int n1, int k, int offD, double *pD, int sdd, double *dD, int offL, double *pL, int sdl, int offA, double *pA, int sda);
void kernel_dgelqf_pd_lla_dlarft8_8_lib8(int n0, int n1, double *pD, double *dD, double *pL, double *pA, double *pT);
void kernel_dlarft_lla_8_lib8(int n0, int n1, double *dD, double *pL, double *pA, double *pT);
void kernel_dlarfb8_rn_lla_8_lib8(int n0, int n1, double *pVL, double *pVA, double *pT, double *pD, double *pL, double *pA);
void kernel_dlarfb8_rn_lla_8_vs_lib8(int n0, int n1, double *pVL, double *pVA, double *pT, double *pD, double *pL, double *pA, int m1);
void kernel_dlarfb8_rn_lla_1_lib8(int n0, int n1, double *pVL, double *pVA, double *pT, double *pD, double *pL, double *pA);
// panel copy / pack
// 24
void kernel_dpack_nn_24_lib8(int kmax, double *A, int lda, double *C, int sdc);
void kernel_dpack_nn_24_vs_lib8(int kmax, double *A, int lda, double *C, int sdc, int m1);
// 16
void kernel_dpacp_nn_16_lib8(int kmax, int offsetA, double *A, int sda, double *B, int sdb);
void kernel_dpacp_nn_16_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int sdb, int m1);
void kernel_dpack_nn_16_lib8(int kmax, double *A, int lda, double *C, int sdc);
void kernel_dpack_nn_16_vs_lib8(int kmax, double *A, int lda, double *C, int sdc, int m1);
// 8
void kernel_dpacp_nn_8_lib8(int kmax, int offsetA, double *A, int sda, double *B);
void kernel_dpacp_nn_8_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int m1);
void kernel_dpacp_tn_8_lib8(int kmax, int offsetA, double *A, int sda, double *B);
void kernel_dpacp_tn_8_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int m1);
void kernel_dpacp_l_nn_8_lib8(int kmax, int offsetA, double *A, int sda, double *B);
void kernel_dpacp_l_nn_8_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int m1);
void kernel_dpacp_l_tn_8_lib8(int kmax, int offsetA, double *A, int sda, double *B);
void kernel_dpacp_l_tn_8_vs_lib8(int kmax, int offsetA, double *A, int sda, double *B, int m1);
void kernel_dpaad_nn_8_lib8(int kmax, double *alpha, int offsetA, double *A, int sda, double *B);
void kernel_dpaad_nn_8_vs_lib8(int kmax, double *alpha, int offsetA, double *A, int sda, double *B, int m1);
void kernel_dpack_nn_8_lib8(int kmax, double *A, int lda, double *C);
void kernel_dpack_nn_8_vs_lib8(int kmax, double *A, int lda, double *C, int m1);
void kernel_dpack_tn_8_lib8(int kmax, double *A, int lda, double *C);
void kernel_dpack_tn_8_vs_lib8(int kmax, double *A, int lda, double *C, int m1);
// 4
void kernel_dpack_tt_4_lib8(int kmax, double *A, int lda, double *C, int sdc); // TODO offsetC
void kernel_dpack_tt_4_vs_lib8(int kmax, double *A, int lda, double *C, int sdc, int m1); // TODO offsetC
// level 2 BLAS
// 16
void kernel_dgemv_n_16_lib8(int k, double *alpha, double *A, int sda, double *x, double *beta, double *y, double *z);
// 8
void kernel_dgemv_n_8_lib8(int k, double *alpha, double *A, double *x, double *beta, double *y, double *z);
void kernel_dgemv_n_8_vs_lib8(int k, double *alpha, double *A, double *x, double *beta, double *y, double *z, int m1);
//void kernel_dgemv_n_8_gen_lib8(int k, double *alpha, double *A, double *x, double *beta, double *y, double *z, int m0, int m1);
void kernel_dgemv_n_8_gen_lib8(int k, double *alpha, int offsetA, double *A, double *x, double *beta, double *y, double *z, int m1);
void kernel_dgemv_t_8_lib8(int k, double *alpha, int offsetA, double *A, int sda, double *x, double *beta, double *y, double *z);
void kernel_dgemv_t_8_vs_lib8(int k, double *alpha, int offsetA, double *A, int sda, double *x, double *beta, double *y, double *z, int n1);
void kernel_dgemv_nt_8_lib8(int kmax, double *alpha_n, double *alpha_t, int offsetA, double *A, int sda, double *x_n, double *x_t, double *beta_t, double *y_t, double *z_n, double *z_t);
void kernel_dgemv_nt_8_vs_lib8(int kmax, double *alpha_n, double *alpha_t, int offsetA, double *A, int sda, double *x_n, double *x_t, double *beta_t, double *y_t, double *z_n, double *z_t, int n1);
void kernel_dsymv_l_8_lib8(int kmax, double *alpha, double *A, int sda, double *x, double *z);
void kernel_dsymv_l_8_vs_lib8(int kmax, double *alpha, double *A, int sda, double *x, double *z, int n1);
void kernel_dsymv_l_8_gen_lib8(int kmax, double *alpha, int offsetA, double *A, int sda, double *x, double *z, int n1);
void kernel_dtrmv_n_ln_8_lib8(int k, double *A, double *x, double *z);
void kernel_dtrmv_n_ln_8_vs_lib8(int k, double *A, double *x, double *z, int m1);
void kernel_dtrmv_n_ln_8_gen_lib8(int k, int offsetA, double *A, double *x, double *z, int m1);
void kernel_dtrmv_t_ln_8_lib8(int k, double *A, int sda, double *x, double *z);
void kernel_dtrmv_t_ln_8_vs_lib8(int k, double *A, int sda, double *x, double *z, int n1);
void kernel_dtrmv_t_ln_8_gen_lib8(int k, int offsetA, double *A, int sda, double *x, double *z, int n1);
void kernel_dtrsv_n_l_inv_8_lib8(int k, double *A, double *inv_diag_A, double *x, double *z);
void kernel_dtrsv_n_l_inv_8_vs_lib8(int k, double *A, double *inv_diag_A, double *x, double *z, int m1, int n1);
void kernel_dtrsv_t_l_inv_8_lib8(int k, double *A, int sda, double *inv_diag_A, double *x, double *z);
void kernel_dtrsv_t_l_inv_8_vs_lib8(int k, double *A, int sda, double *inv_diag_A, double *x, double *z, int m1, int n1);
//
// lib4
//
// level 2 BLAS
// 12
@@ -413,10 +565,10 @@ void kernel_drowsw_lib4(int kmax, double *pA, double *pC);
// merged routines
// 12x4
void kernel_dgemm_dtrsm_nt_rl_inv_12x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int km, int kn);
void kernel_dgemm_dtrsm_nt_rl_inv_12x4_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E);
void kernel_dsyrk_dpotrf_nt_l_12x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int km, int kn);
void kernel_dgemm_dtrsm_nt_rl_inv_12x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int km, int kn);
void kernel_dsyrk_dpotrf_nt_l_12x4_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
void kernel_dsyrk_dpotrf_nt_l_12x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int km, int kn);
// 4x12
void kernel_dgemm_dtrsm_nt_rl_inv_4x12_vs_lib4(int kp, double *Ap, double *Bp, int sdbp, int km_, double *Am, double *Bm, int sdbm, double *C, double *D, double *E, int sde, double *inv_diag_E, int km, int kn);
// 8x8
@@ -425,8 +577,8 @@ void kernel_dsyrk_dpotrf_nt_l_8x8_vs_lib4(int kp, double *Ap, int sdap, double *
void kernel_dgemm_dtrsm_nt_rl_inv_8x8l_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int sdb, int km_, double *Am, int sdam, double *Bm, int sdbm, double *C, int sdc, double *D, int sdd, double *E, int sde, double *inv_diag_E, int km, int kn);
void kernel_dgemm_dtrsm_nt_rl_inv_8x8u_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int sdb, int km_, double *Am, int sdam, double *Bm, int sdbm, double *C, int sdc, double *D, int sdd, double *E, int sde, double *inv_diag_E, int km, int kn);
// 8x4
void kernel_dgemm_dtrsm_nt_rl_inv_8x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int km, int kn);
void kernel_dgemm_dtrsm_nt_rl_inv_8x4_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E);
void kernel_dgemm_dtrsm_nt_rl_inv_8x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *E, double *inv_diag_E, int km, int kn);
void kernel_dsyrk_dpotrf_nt_l_8x4_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D);
void kernel_dsyrk_dpotrf_nt_l_8x4_vs_lib4(int kp, double *Ap, int sdap, double *Bp, int km_, double *Am, int sdam, double *Bm, double *C, int sdc, double *D, int sdd, double *inv_diag_D, int km, int kn);
// 4x8
@@ -434,16 +586,16 @@ void kernel_dgemm_dtrsm_nt_rl_inv_4x8_vs_lib4(int kp, double *Ap, double *Bp, in
// 4x4
void kernel_dgemm_dtrsm_nt_rl_inv_4x4_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E);
void kernel_dgemm_dtrsm_nt_rl_inv_4x4_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E, int km, int kn);
void kernel_dsyrk_dpotrf_nt_l_4x4_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
void kernel_dsyrk_dpotrf_nt_l_4x4_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D);
void kernel_dsyrk_dpotrf_nt_l_4x4_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
// 4x2
void kernel_dgemm_dtrsm_nt_rl_inv_4x2_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E);
void kernel_dgemm_dtrsm_nt_rl_inv_4x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *E, double *inv_diag_E, int km, int kn);
void kernel_dsyrk_dpotrf_nt_l_4x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
void kernel_dsyrk_dpotrf_nt_l_4x2_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D);
void kernel_dsyrk_dpotrf_nt_l_4x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
// 2x2
void kernel_dsyrk_dpotrf_nt_l_2x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
void kernel_dsyrk_dpotrf_nt_l_2x2_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D);
void kernel_dsyrk_dpotrf_nt_l_2x2_vs_lib4(int kp, double *Ap, double *Bp, int km_, double *Am, double *Bm, double *C, double *D, double *inv_diag_D, int km, int kn);
/*
*
@@ -1034,6 +1186,53 @@ void kernel_dgemm_nt_8xn_p0_lib44cc(int n, int k, double *alpha, double *A, int
// A, B panel-major bs=8; C, D column-major
// 24x8
void kernel_dgemm_nt_24x8_lib88cc(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_24x8_vs_lib88cc(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// 16x8
void kernel_dgemm_nt_16x8_lib88cc(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_16x8_vs_lib88cc(int kmax, double *alpha, double *A, int sda, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// 8x8
void kernel_dgemm_nt_8x8_lib88cc(int kmax, double *alpha, double *A, double *B, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_8x8_vs_lib88cc(int kmax, double *alpha, double *A, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// A, panel-major bs=8; B, C, D column-major
// 24x8
void kernel_dgemm_nt_24x8_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_24x8_vs_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
void kernel_dgemm_nn_24x8_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nn_24x8_vs_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// 16x8
void kernel_dgemm_nt_16x8_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_16x8_vs_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
void kernel_dgemm_nn_16x8_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nn_16x8_vs_lib8ccc(int kmax, double *alpha, double *A, int sda, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// 8x8
void kernel_dgemm_nt_8x8_lib8ccc(int kmax, double *alpha, double *A, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_8x8_vs_lib8ccc(int kmax, double *alpha, double *A, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
void kernel_dgemm_nn_8x8_lib8ccc(int kmax, double *alpha, double *A, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nn_8x8_vs_lib8ccc(int kmax, double *alpha, double *A, double *B, int ldb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// B, panel-major bs=8; A, C, D column-major
// 8x24
void kernel_dgemm_nt_8x24_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_8x24_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
void kernel_dgemm_tt_8x24_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_tt_8x24_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// 8x16
void kernel_dgemm_nt_8x16_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_8x16_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
void kernel_dgemm_tt_8x16_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_tt_8x16_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, int sdb, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// 8x8
void kernel_dgemm_nt_8x8_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_nt_8x8_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
void kernel_dgemm_tt_8x8_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, double *beta, double *C, int ldc, double *D, int ldd);
void kernel_dgemm_tt_8x8_vs_libc8cc(int kmax, double *alpha, double *A, int lda, double *B, double *beta, double *C, int ldc, double *D, int ldd, int m1, int n1);
// aux
void kernel_dvecld_inc1(int kmax, double *x);
void kernel_dveccp_inc1(int kmax, double *x, double *y);

View File

@@ -97,14 +97,18 @@ void blasfeo_strsv_ltu(int m, struct blasfeo_smat *sA, int ai, int aj, struct bl
void blasfeo_strsv_unn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= inv( A' ) * x, A (m)x(m) upper, transposed, not_unit
void blasfeo_strsv_utn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A * x ; A lower triangular
void blasfeo_strmv_lnn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A * x ; A lower triangular, unit diagonal
void blasfeo_strmv_lnu(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A' * x ; A lower triangular
void blasfeo_strmv_ltn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A' * x ; A lower triangular, unit diagonal
void blasfeo_strmv_ltu(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= beta * y + alpha * A * x ; A upper triangular
void blasfeo_strmv_unn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A' * x ; A upper triangular
void blasfeo_strmv_utn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A * x ; A lower triangular
void blasfeo_strmv_lnn(int m, int n, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A' * x ; A lower triangular
void blasfeo_strmv_ltn(int m, int n, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z_n <= beta_n * y_n + alpha_n * A * x_n
// z_t <= beta_t * y_t + alpha_t * A' * x_t
void blasfeo_sgemv_nt(int m, int n, float alpha_n, float alpha_t, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx_n, int xi_n, struct blasfeo_svec *sx_t, int xi_t, float beta_n, float beta_t, struct blasfeo_svec *sy_n, int yi_n, struct blasfeo_svec *sy_t, int yi_t, struct blasfeo_svec *sz_n, int zi_n, struct blasfeo_svec *sz_t, int zi_t);

View File

@@ -97,19 +97,24 @@ void blasfeo_ref_strsv_ltu(int m, struct blasfeo_smat *sA, int ai, int aj, struc
void blasfeo_ref_strsv_unn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= inv( A' ) * x, A (m)x(m) upper, transposed, not_unit
void blasfeo_ref_strsv_utn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A * x ; A lower triangular
void blasfeo_ref_strmv_lnn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A * x ; A lower triangular, unit diagonal
void blasfeo_ref_strmv_lnu(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A' * x ; A lower triangular
void blasfeo_ref_strmv_ltn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A' * x ; A lower triangular, unit diagonal
void blasfeo_ref_strmv_ltu(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= beta * y + alpha * A * x ; A upper triangular
void blasfeo_ref_strmv_unn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A' * x ; A upper triangular
void blasfeo_ref_strmv_utn(int m, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A * x ; A lower triangular
void blasfeo_ref_strmv_lnn(int m, int n, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z <= A' * x ; A lower triangular
void blasfeo_ref_strmv_ltn(int m, int n, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, struct blasfeo_svec *sz, int zi);
// z_n <= beta_n * y_n + alpha_n * A * x_n
// z_t <= beta_t * y_t + alpha_t * A' * x_t
void blasfeo_ref_sgemv_nt(int m, int n, float alpha_n, float alpha_t, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx_n, int xi_n, struct blasfeo_svec *sx_t, int xi_t, float beta_n, float beta_t, struct blasfeo_svec *sy_n, int yi_n, struct blasfeo_svec *sy_t, int yi_t, struct blasfeo_svec *sz_n, int zi_n, struct blasfeo_svec *sz_t, int zi_t);
// z <= beta * y + alpha * A * x, where A is symmetric and only the lower triangular patr of A is accessed
void blasfeo_ref_ssymv_l(int m, int n, float alpha, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, float beta, struct blasfeo_svec *sy, int yi, struct blasfeo_svec *sz, int zi);
void blasfeo_ref_ssymv_l(int m, float alpha, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, float beta, struct blasfeo_svec *sy, int yi, struct blasfeo_svec *sz, int zi);
void blasfeo_ref_ssymv_l_mn(int m, int n, float alpha, struct blasfeo_smat *sA, int ai, int aj, struct blasfeo_svec *sx, int xi, float beta, struct blasfeo_svec *sy, int yi, struct blasfeo_svec *sz, int zi);
// diagonal

View File

@@ -550,6 +550,16 @@ void kernel_spotrf_nt_l_4x4_lib44cc(int kmax, float *A, float *B, float *C, int
void kernel_spotrf_nt_l_4x4_vs_lib44cc(int kmax, float *A, float *B, float *C, int ldc, float *D, int ldd, float *dD, int m1, int n1);
// B panel-major bs=8; A, C, D column-major
// 4x24
void kernel_sgemm_nt_4x24_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd);
void kernel_sgemm_nt_4x24_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
void kernel_sgemm_tt_4x24_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd);
void kernel_sgemm_tt_4x24_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
// 4x16
void kernel_sgemm_nt_4x16_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd);
void kernel_sgemm_nt_4x16_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
void kernel_sgemm_tt_4x16_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd);
void kernel_sgemm_tt_4x16_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, int sdb, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);
// 8x8
void kernel_sgemm_nt_8x8_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, float *beta, float *C, int ldc, float *D, int ldd);
void kernel_sgemm_nt_8x8_vs_libc8cc(int kmax, float *alpha, float *A, int lda, float *B, float *beta, float *C, int ldc, float *D, int ldd, int m1, int n1);

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.