cython 的认识

经常听到 cpython 的概念,这个概念和 python 有什么联系呢?

什么是 cython?

  • Cython 是一个快速生成 Python 扩展模块的工具
  • 从语法层面上来讲是 Python 语法和 C 语言语法的混血
  • 当 Python 性能遇到瓶颈时,Cython 直接将 C 的原生速度植入 Python 程序,这样使 Python 程序无需使用 C 重写,能快速整合原有的 Python 程序,这样使得开发效率和执行效率都有很大的提高,而这些中间的部分,都是 Cython 帮我们做了

cython 有哪些优势?

  • 编写 python 代码,通过 cython,可以方便地嵌入到 C/C++ 代码中
  • 通过静态类型声明,将 Python 代码运行速度提升到 C 语言级别
  • 使用组合的源代码级别 debugging,方便在 Python、C、C++ 之间调试
  • 有效与大型数据集交互,比如 Numpy

什么是 cython 编程语言?

  • cython 编程语言是 python 编程语言的超集
  • 还支持调用 C 函数和在变量和类属性上声明 C 类型

Cython 与 CPython 有什么区别?

  • Cython 是一种混合编程语言,可以让 Python 调用 C++ 的容器,例如 vector
  • CPython 是一种被广泛使用的 python 解释器,类似的解释器还有 pypy,JPython 等等

.py、.pxd、.pyx、pyc、pyo、.pyd 分别是什么文件?

  • .py : python 的源代码文件
  • .pxd: 文件是由 Cython 编程语言 “编写” 而成的 Python 扩展模块头文件,类似于 C 语言的 .h 头文件,.pxd 文件中有 Cython 模块要包含的 Cython 声明 (或代码段) , 可共享外部 C 语言声明,也能包含 C 编译器内联函数。.pxd 文件还可为 .pyx 文件模块提供 Cython 接口,以便其它 Cython 模块可使用比 Python 更高效的协议与之进行通信
  • .pyx : 由 Cython 编程语言 “编写” 而成的 Python 扩展模块源代码文件,类似于 C 语言的 .c 源代码文件,.pyx 文件中有 Cython 模块的源代码,不像 Python 语言可直接解释使用的 .py 文件,.pyx 文件必须先被编译成 .c 文件,再编译成 .pyd (Windows 平台) 或 .so (Linux 平台) 文件,才可作为模块 import 导入使用
  • .pyc: Python 源代码 import 后,编译生成的字节码
  • .pyo: Python 源代码编译优化生成的字节码。pyo 比 pyc 并没有优化多少,只是去掉了断言
  • .pyd: 非 Python,由其它编程语言 “编写 - 编译” 生成的 Python 扩展模块,相当于 windows 下的 Python 动态链接库
  • .py, .pyc, .pyo 运行速度几乎无差别,只是 pyc, pyo 文件加载的速度更快,但不能用文本编辑器查看内容,反编译不太容易

cython 的工作流程?

如何获得.pyx 文件?

  • 如果该文件不含 cython 特有的语法,可以直接将.py 文件替换为.pyx 文件
  • 如果含有 cython 特有语法,则需要手动编写 cython 语言

cython 的基本使用方法?

  • 使用 Cython 编译 Python 代码时,务必要安装 C/C++ 编译器(windows 可以直接安装 Visiual Studio 的开发环境)
  • 安装 Cython 库
    1
    pip install Cython
  • 编写.pyx 文件,如 test.pyx(或者直接写.py 也可以?只要文件里不使用 cpython 独有的语法,而是使用纯 python 语法)
    1
    2
    def say_hello():
    print("hello world")
  • 编写 setup.py 文件
    1
    2
    3
    4
    5
    # cythonize()是Cython提供将Python代码转换成C代码的API
    # setup是Python提供的一种发布Python模块的方法
    from distutils.core import setup
    from Cython.Build import cythonize
    setup(ext_modules = cythonize("test.py"))
  • 编译
    1
    2
    3
    # build_ext是指明python生成C/C++的扩展模块(build C/C++ extensions (compile/link to build directory))
    # --inplace指示 将编译后的扩展模块直接放在与test.py同级的目录中
    python setup.py build_ext --inplace
  • 使用
    1
    2
    3
    # 像使用Python的任意模块一样,直接import即可
    import test
    test.say_hello()

使用 Cython 优化自定义函数?

  • 创建.pyx 文件
    1
    cp compute.py compute1.pyx 
  • 创建 setup.py
    1
    2
    3
    4
    5
    6
    7
    # setup1.py
    from distutils.core import setup
    from Cython.Build import cythonize
    setup(
    name='compute_module',
    ext_modules=cythonize('compute1.pyx'),
    )
  • 编译.pyx 文件
    1
    python setup1.py build

使用案例比较说明 Cython 和 Python 程序性能提升?

  • 自定义函数
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    # compute.py
    # coding=utf-8
    import math
    def spherical_distance(lon1, lat1, lon2, lat2):
    radius = 3956
    x = math.pi/180.0
    a = (90.0 - lat1)*x
    b = (90.0 - lat2)*x
    theta = (lon2 - lon1)*x
    distance = math.acos(math.cos(a)*math.cos(b)) + (math.sin(a) * math.sin(b) * math.cos(theta))
    return radius * distance
    def f_compute(a, x, N):
    s = 0
    dx = (x - a)/N
    for i in range(N):
    s += ((a + i * dx) ** 2 - (a + i * dx))
    return s * dx
  • 性能测试
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    # test.py
    #!/usr/bin/env python
    # coding=utf-8
    import compute
    import time
    lon1, lat1, lon2, lat2 = -72.345, 34.323, -61.823, 54.826
    start_time = time.clock()
    compute.f_compute(3.2, 6.9, 1000000)
    end_time = time.clock()
    print "runing1 time: %f s" % (end_time - start_time)
    start_time = time.clock()
    for i in range(1000000):
    compute.spherical_distance(lon1, lat1, lon2, lat2)
    end_time = time.clock()
    print "runing2 time: %f s" % (end_time - start_time)

使用案例比较说明 Cython 和 Python 在使用 numpy 时的性能比较?

  • 原始 python 代码
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # dot_python.py
    import numpy as np
    def naive_dot(a, b):
    if a.shape[1] != b.shape[0]:
    raise ValueError('shape not matched')
    n, p, m = a.shape[0], a.shape[1], b.shape[1]
    c = np.zeros((n, m), dtype=np.float32)
    for i in xrange(n):
    for j in xrange(m):
    s = 0
    for k in xrange(p):
    s += a[i, k] * b[k, j]
    c[i, j] = s
    return c
  • .pyx 代码
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    # dot_cython.pyx
    import numpy as np
    cimport numpy as np
    cimport cython
    @cython.boundscheck(False)
    @cython.wraparound(False)
    cdef np.ndarray[np.float32_t, ndim=2] _naive_dot(np.ndarray[np.float32_t, ndim=2] a, np.ndarray[np.float32_t, ndim=2] b):
    cdef np.ndarray[np.float32_t, ndim=2] c
    cdef int n, p, m
    cdef np.float32_t s
    if a.shape[1] != b.shape[0]:
    raise ValueError('shape not matched')
    n, p, m = a.shape[0], a.shape[1], b.shape[1]
    c = np.zeros((n, m), dtype=np.float32)
    for i in xrange(n):
    for j in xrange(m):
    s = 0
    for k in xrange(p):
    s += a[i, k] * b[k, j]
    c[i, j] = s
    return c
    def naive_dot(a, b):
    return _naive_dot(a, b)
  • 性能比较
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    In [4]: a = np.random.randn(30, 50).astype(np.float32)
    In [5]: b = np.random.randn(50, 20).astype(np.float32)
    In [6]: %timeit -n 100 -r 3 dot_python.naive_dot(a, b)
    13.9 ms ± 44.6 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
    In [7]: %timeit -n 100 -r 3 dot_cython.naive_dot(a, b)
    35.3 µs ± 11.3 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
    In [8]: %timeit -n 100 -r 3 np.dot(a, b)
    The slowest run took 23.35 times longer than the fastest. This could mean that an intermediate result is being cached.
    135 µs ± 169 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
    In [9]: %timeit -n 100 -r 3 np.dot(a, b)
    17.2 µs ± 4.48 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)