使用Cython包装LAPACKE函数

前端之家收集整理的这篇文章主要介绍了使用Cython包装LAPACKE函数前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。

我正在尝试使用Cython包装LAPACK函数dgtsv(三对角方程组的求解器).

我遇到了this previous answer,但由于dgtsv不是scipy.linalg中包含的LAPACK函数之一,我认为我不能使用这种特殊的方法.相反,我一直在努力追随this example.

这是我的lapacke.pxd文件内容

@H_404_9@ctypedef int lapack_int cdef extern from "lapacke.h" nogil: int LAPACK_ROW_MAJOR int LAPACK_COL_MAJOR lapack_int LAPACKE_dgtsv(int matrix_order,lapack_int n,lapack_int nrhs,double * dl,double * d,double * du,double * b,lapack_int ldb)

…这是我在_solvers.pyx中的瘦Cython包装器:

@H_404_9@#!python cimport cython from lapacke cimport * cpdef TDMA_lapacke(double[::1] DL,double[::1] D,double[::1] DU,double[:,::1] B): cdef: lapack_int n = D.shape[0] lapack_int nrhs = B.shape[1] lapack_int ldb = B.shape[0] double * dl = &DL[0] double * d = &D[0] double * du = &DU[0] double * b = &B[0,0] lapack_int info info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR,n,nrhs,dl,d,du,b,ldb) return info

…这是一个Python包装器和测试脚本:

@H_404_9@import numpy as np from scipy import sparse from cymodules import _solvers def trisolve_lapacke(dl,inplace=False): if (dl.shape[0] != du.shape[0] or dl.shape[0] != d.shape[0] - 1 or b.shape != d.shape): raise ValueError('Invalid diagonal shapes') if b.ndim == 1: # b is (LDB,NRHS) b = b[:,None] # be sure to force a copy of d and b if we're not solving in place if not inplace: d = d.copy() b = b.copy() # this may also force copies if arrays are improperly typed/noncontiguous dl,b = (np.ascontiguousarray(v,dtype=np.float64) for v in (dl,b)) # b will now be modified in place to contain the solution info = _solvers.TDMA_lapacke(dl,b) print info return b.ravel() def test_trisolve(n=20000): dl = np.random.randn(n - 1) d = np.random.randn(n) du = np.random.randn(n - 1) M = sparse.diags((dl,du),(-1,1),format='csc') x = np.random.randn(n) b = M.dot(x) x_hat = trisolve_lapacke(dl,b) print "||x - x_hat|| = ",np.linalg.norm(x - x_hat)

不幸的是,test_trisolve只是对_solvers.TDMA_lapacke调用的段错误.
我很确定我的setup.py是正确的 – ldd _solvers.so显示_solvers.so在运行时链接到正确的共享库.

我不确定如何从这里开始 – 任何想法?

简要更新:

对于较小的n值,我倾向于不立即得到段错误,但我确实得到了无意义的结果(|| x – x_hat ||应该非常接近0):

@H_404_9@In [28]: test_trisolve2.test_trisolve(10) 0 ||x - x_hat|| = 6.23202576396 In [29]: test_trisolve2.test_trisolve(10) -7 ||x - x_hat|| = 3.88623414288 In [30]: test_trisolve2.test_trisolve(10) 0 ||x - x_hat|| = 2.60190676562 In [31]: test_trisolve2.test_trisolve(10) 0 ||x - x_hat|| = 3.86631743386 In [32]: test_trisolve2.test_trisolve(10) Segmentation fault

通常LAPACKE_dgtsv返回代码0(应该表示成功),但偶尔我得到-7,这意味着参数7(b)具有非法值.发生的事情是,只有b的第一个值实际上被修改了.如果我继续调用test_trisolve,即使n很小,我最终也会遇到段错误.

最佳答案
好吧,我最终想通了 – 似乎我误解了在这种情况下行和列主要引用的内容.

由于C连续数组遵循行主顺序,我假设我应该将LAPACK_ROW_MAJOR指定为LAPACKE_dgtsv的第一个参数.

事实上,如果我改变

@H_404_9@info = LAPACKE_dgtsv(LAPACK_ROW_MAJOR,...)

@H_404_9@info = LAPACKE_dgtsv(LAPACK_COL_MAJOR,...)

然后我的功能工作:

@H_404_9@test_trisolve2.test_trisolve() 0 ||x - x_hat|| = 6.67064747632e-12

这对我来说似乎很反直 – 有人可以解释为什么会这样吗?

原文链接:https://www.f2er.com/python/439765.html

猜你在找的Python相关文章