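Chapter 1 Preliminaries (Reference Answers)

All sessions below assume NumPy has already been imported (import numpy as np).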

Ex1: Matrix multiplication with a list comprehension

In [4]: M1 = np.random.rand(2,3)
In [5]: M2 = np.random.rand(3,4)
In [6]: res = [[sum([M1[i][k] * M2[k][j] for k in range(M1.shape[1])]) for j in range(M2.shape[1])] for i in range(M1.shape[0])]
In [7]: (np.abs(M1@M2 - res) < 1e-15).all()
Out[7]: True
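The same product can be written without index arithmetic by pairing each row of the left matrix with each column of the right one via zip. A minimal sketch, assuming plain nested Python lists as input; the helper name `matmul_zip` is hypothetical. Comparing with `np.allclose` instead of a fixed 1e-15 threshold avoids spurious failures caused by the Python sum and the BLAS-backed `@` adding terms in a different order.

```python
import numpy as np

def matmul_zip(X, Y):
    """Hypothetical helper: product of X (r x s) and Y (s x t) given as nested lists."""
    # zip(*Y) yields the columns of Y as tuples; materialise them once for reuse
    cols = list(zip(*Y))
    # Entry (i, j) is the inner product of row i of X with column j of Y
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in X]

M1 = np.random.rand(2, 3)
M2 = np.random.rand(3, 4)
res = matmul_zip(M1.tolist(), M2.tolist())
print(np.allclose(M1 @ M2, res))
```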

Ex2: Updating a matrix

In [8]: A = np.arange(1,10).reshape(3,-1)
In [9]: B = A*(1/A).sum(1).reshape(-1,1)
In [10]: B
Out[10]: 
array([[1.83333333, 3.66666667, 5.5       ],
       [2.46666667, 3.08333333, 3.7       ],
       [2.65277778, 3.03174603, 3.41071429]])
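The update computed above is $B_{ij} = A_{ij}\sum_k 1/A_{ik}$: every element is multiplied by the sum of reciprocals of its own row. reshape(-1,1) turns the vector of row sums into a column so that broadcasting scales row i by the i-th sum. The sketch below checks this against explicit loops (used only for verification) and shows the equivalent keepdims=True spelling; the variable names are illustrative.

```python
import numpy as np

A = np.arange(1, 10).reshape(3, -1)

# Vectorised: row sums of reciprocals, reshaped into a (3, 1) column for broadcasting
B_vec = A * (1 / A).sum(1).reshape(-1, 1)

# Equivalent: keepdims=True keeps the summed axis with length 1, also giving (3, 1)
B_keep = A * (1 / A).sum(axis=1, keepdims=True)

# Explicit loops computing B[i, j] = A[i, j] * sum_k 1 / A[i, k]
B_loop = np.empty(A.shape)
for i in range(A.shape[0]):
    row_recip_sum = sum(1 / A[i, k] for k in range(A.shape[1]))
    for j in range(A.shape[1]):
        B_loop[i, j] = A[i, j] * row_recip_sum

print(np.allclose(B_vec, B_loop), np.allclose(B_vec, B_keep))
```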

Ex3: The chi-square statistic

In [11]: np.random.seed(0)
In [12]: A = np.random.randint(10, 20, (8, 5))
In [13]: B = A.sum(0)*A.sum(1).reshape(-1, 1)/A.sum()
In [14]: res = ((A-B)**2/B).sum()
In [15]: res
Out[15]: 11.842696601945802
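Here B holds the expected counts under independence, $B_{ij} = (\sum_k A_{ik})(\sum_k A_{kj}) / \sum_{k,l} A_{kl}$, and the statistic is $\sum_{i,j} (A_{ij}-B_{ij})^2/B_{ij}$. The broadcast product A.sum(0)*A.sum(1).reshape(-1, 1) is an outer product of row totals and column totals; the sketch below writes it with np.outer instead, which may read more explicitly. The names row_tot, col_tot and expected are illustrative.

```python
import numpy as np

np.random.seed(0)
A = np.random.randint(10, 20, (8, 5))

row_tot = A.sum(axis=1)   # shape (8,): one total per row
col_tot = A.sum(axis=0)   # shape (5,): one total per column
# Expected counts: outer product of the totals divided by the grand total
expected = np.outer(row_tot, col_tot) / A.sum()
chi2 = ((A - expected) ** 2 / expected).sum()

# Same matrix as the broadcast form used in the session above
B = A.sum(0) * A.sum(1).reshape(-1, 1) / A.sum()
print(np.allclose(expected, B), chi2)
```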

Ex4: Improving the performance of a matrix computation

Original method:

In [16]: np.random.seed(0)
In [17]: m, n, p = 100, 80, 50
In [18]: B = np.random.randint(0, 2, (m, p))
In [19]: U = np.random.randint(0, 2, (p, n))
In [20]: Z = np.random.randint(0, 2, (m, n))
In [21]: def solution(B=B, U=U, Z=Z):
   ....:     L_res = []
   ....:     for i in range(m):
   ....:         for j in range(n):
   ....:             norm_value = ((B[i]-U[:,j])**2).sum()
   ....:             L_res.append(norm_value*Z[i][j])
   ....:     return sum(L_res)
   ....: 
In [22]: solution(B, U, Z)
Out[22]: 100566

Improved method:

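Let B_i be row i of B and U_j be column j of U, so the quantity to compute is $\sum_{i,j} Y_{ij} Z_{ij}$ with $Y_{ij} = \|B_i - U_j\|_2^2 = \sum_k B_{ik}^2 + \sum_k U_{kj}^2 - 2\sum_k B_{ik}U_{kj}$. The first and second terms are the row sums of squares of B and the column sums of squares of U, and the third term is twice an inner product. Y can therefore be written as three parts: an m×n all-ones matrix whose i-th row is multiplied by the sum of squares of row i of B; an all-ones matrix of the same size whose j-th column is multiplied by the sum of squares of column j of U; and twice the matrix product of B and U, which is subtracted. The result is: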
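The decomposition can be checked numerically on a small case. A sketch, assuming the same kind of 0/1 matrices but smaller sizes chosen here for illustration: it builds Y from its definition and from the three-part formula, with broadcasting standing in for the all-ones matrices (a column of shape (m, 1) plus a row of shape (n,) broadcasts to (m, n)).

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, p = 4, 3, 5  # small illustrative sizes
B = rng.integers(0, 2, (m, p))
U = rng.integers(0, 2, (p, n))
Z = rng.integers(0, 2, (m, n))

# Definition: Y[i, j] = ||B[i] - U[:, j]||^2
Y_def = np.array([[((B[i] - U[:, j]) ** 2).sum() for j in range(n)] for i in range(m)])

# Decomposition: row sums of squares of B as an (m, 1) column,
# plus column sums of squares of U as a length-n row, minus twice B @ U
Y_dec = (B ** 2).sum(1).reshape(-1, 1) + (U ** 2).sum(0) - 2 * B @ U

# Integer inputs, so the two constructions agree exactly
print((Y_def == Y_dec).all(), (Y_dec * Z).sum())
```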

In [23]: (((B**2).sum(1).reshape(-1,1) + (U**2).sum(0) - 2*B@U)*Z).sum()
Out[23]: 100566

Comparing their performance:

In [24]: %timeit -n 30 solution(B, U, Z)
43.8 ms +- 1.29 ms per loop (mean +- std. dev. of 7 runs, 30 loops each)
In [25]: %timeit -n 30 ((np.ones((m,n))*(B**2).sum(1).reshape(-1,1) +\
   ....:                   np.ones((m,n))*(U**2).sum(0) - 2*B@U)*Z).sum()
   ....: 
602 us +- 15.6 us per loop (mean +- std. dev. of 7 runs, 30 loops each)
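%timeit is an IPython magic; in a plain Python script the standard-library timeit module gives a comparable measurement. A sketch, assuming the same data as above; the function names loop_version and vectorised_version are hypothetical, and absolute timings will vary by machine. The vectorised function uses broadcasting rather than the explicit np.ones matrices in the timed expression above; both compute the same sum.

```python
import timeit
import numpy as np

np.random.seed(0)
m, n, p = 100, 80, 50
B = np.random.randint(0, 2, (m, p))
U = np.random.randint(0, 2, (p, n))
Z = np.random.randint(0, 2, (m, n))

def loop_version():
    # Double loop over (i, j), as in the original method
    total = 0
    for i in range(m):
        for j in range(n):
            total += ((B[i] - U[:, j]) ** 2).sum() * Z[i, j]
    return total

def vectorised_version():
    # (m, 1) + (n,) - (m, n) broadcasts to (m, n); weight by Z and sum
    return (((B ** 2).sum(1).reshape(-1, 1) + (U ** 2).sum(0) - 2 * B @ U) * Z).sum()

assert loop_version() == vectorised_version()
for f in (loop_version, vectorised_version):
    t = timeit.timeit(f, number=5) / 5  # average seconds per call over 5 calls
    print(f"{f.__name__}: {t * 1e3:.3f} ms per call")
```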

Ex5: Length of the longest run of consecutive integers

In [26]: f = lambda x:np.diff(np.nonzero(np.r_[1,np.diff(x)!=1,1])).max()
In [27]: f([1,2,5,6,7])
Out[27]: 3
In [28]: f([3,2,1,2,3,4,6])
Out[28]: 4
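The one-liner works by marking run boundaries: np.diff(x) != 1 is True wherever two neighbours break the +1 step, and padding with 1 at both ends adds sentinel boundaries at index 0 and at index len(x). The gaps between successive boundary indices are the run lengths. A sketch spelling out the intermediate steps, with a plain loop for comparison (the names max_run_vectorised and max_run_loop are hypothetical):

```python
import numpy as np

def max_run_vectorised(x):
    steps = np.diff(x)                    # differences between neighbours
    breaks = np.r_[1, steps != 1, 1]      # 1 marks a boundary; both ends padded
    boundaries = np.nonzero(breaks)[0]    # run start indices, plus len(x) at the end
    return np.diff(boundaries).max()      # run lengths are gaps between boundaries

def max_run_loop(x):
    best = cur = 1
    for prev, nxt in zip(x, x[1:]):
        # Extend the current run on a +1 step, otherwise start a new one
        cur = cur + 1 if nxt - prev == 1 else 1
        best = max(best, cur)
    return best

for x in ([1, 2, 5, 6, 7], [3, 2, 1, 2, 3, 4, 6]):
    print(max_run_vectorised(x), max_run_loop(x))  # prints 3 3, then 4 4
```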