3.2 线性判别函数

3.2.1 线性判别函数

两类问题的判别函数

假设x是二维模式样本x=(x1 x2)Tx=(x_1\ x_2)^T,模式的平面图如下:

此时,分属于ω1\omega_1ω2\omega_2两类的模式可以用一个直线方程进行划分:

d(x)=w1x1+w2x2+w3=0d(x) = w_1x_1 + w_2x_2 + w_3 = 0

其中,x1x_1x2x_2称为坐标变量w1w_1w2w_2w3w_3称为参数方程,则将一个位置模式带入,有:

  • d(x)>0d(x)>0,则xω1x\in \omega_1

  • d(x)<0d(x)<0,则xω2x\in \omega_2

此处的d(x)=0d(x)=0就称为判别函数

用判别函数进行分类的两个因素

  • 判别函数的几何性质

    • 线性的

    • 非线性的

  • 判别函数的系数:使用给定的模式样本确定判别函数的系数

n维线性判别函数

一个n维线性判别函数可以写为:

d(x)=w1x1+w2x2++wnxn+wn+1=w0Tx+wn+1d(x)=w_1x_1 + w_2x_2 + \cdots + w_nx_n + w_{n+1} = w_0^Tx + w_{n+1}

其中,w0=(w1,w2,,wn)T\boldsymbol{w}_0=(w_1,w_2,\dots,w_n)^T称为权向量

此外,d(x)d(x)还可以写为:

d(x)=wTxd(x)=w^Tx

其中,x=(x1,x2,,xn,1)\boldsymbol{x}=(x_1,x_2,\dots,x_n,1)称为增广模式向量w0=(w1,w2,,wn,wn+1)T\boldsymbol{w}_0=(w_1,w_2,\dots,w_n,w_{n+1})^T称为增广权向量

3.2.2 线性判别函数的三种形式

一、ωi\ωi\omega_i\backslash\overline{\omega_i}两分法

每条判别函数只区分是否属于某一类

上图中,白色区域为分类失败区域

  • 将M分类问题分为M个单独的分类问题

  • 需要M条线

  • 不一定能够找到判别函数区分开其它所有类别

二、ωi\ωj\omega_i\backslash\overline{\omega_j}两分法

每一条线分开两种类别

三、类别1转化为类别2

可以通过以下方式将方式1的判别函数转化为方式2的:

d12(x)=d1(x)d2(x)=0d13(x)=d1(x)d3(x)=0d23(x)=d2(x)d3(x)=0d_{12}(x) = d_1(x) - d_2(x) = 0 \\ d_{13}(x) = d_1(x) - d_3(x) = 0 \\ d_{23}(x) = d_2(x) - d_3(x) = 0

四、小结

  • 线性可分:模式分类若可以用任一线性函数划分,则称这些模式是线性可分

  • 一旦线性函数的系数wkw_k被确定,则此函数就可以作为分类的基础

  • 两种分类法的比较

    • 对于M分类,法一需要M个判别函数,法二需要M(M1)2\frac{M(M-1)}{2}个判别函数,因此当模式较多时,法二需要更多的判别函数

    • 但是对于法一而言,并不是每一种情况都是线性可分的,因此法二对模式是线性可分的概率比法一大

绘图代码

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-5, 10, 100)

y1 = x
y2 = -x + 5
y3 = np.ones_like(x)

plt.plot(x, y1, label='$d_1(x)=-x_1+x_2 = 0$')
plt.plot(x, y2, label='$d_2(x)=x_1+x_2-5 = 0$')
plt.plot(x, y3, label='$d_3(x)=-x_2+1 = 0$')

plt.axhline(y=0, color='black', linestyle='--', linewidth=0.8)
plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.axvline(x=0, color='black', linestyle='--', linewidth=0.8)

plt.fill_between(x, y2, np.maximum(y1, y3), where=(x <= 2.5), color='blue', alpha=0.2)
plt.fill_between(x, y1, np.maximum(y2, y3), where=(x >= 2.5), color='orange', alpha=0.2)
plt.fill([-5, 1, 4, 10], [-5, 1, 1, -5], color='green', alpha=0.2)

plt.annotate('$\\omega_1$\n$d_1(x)>0$\n$d_2(x)<0$\n$d_3(x)<0$',
             xy=(-4, 6), xytext=(-4, 4), fontsize=12, color='black')
plt.annotate('$\\omega_2$\n$d_2(x)>0$\n$d_1(x)<0$\n$d_3(x)<0$',
             xy=(7, 6), xytext=(7, 3), fontsize=12, color='black')
plt.annotate('$\\omega_3$\n$d_3(x)>0$\n$d_1(x)<0$\n$d_2(x)<0$',
             xy=(-1, -4), xytext=(1, -4), fontsize=12, color='black')

plt.xlim(-5, 10)
plt.ylim(-5, 10)

plt.legend(loc='lower right')

plt.show()
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-5, 10, 100)
y = np.linspace(-5, 10, 100)

y1 = x
y2 = -x + 5
x1 = np.full_like(y, 2)

plt.plot(x, y1, label='$d_{23}(x)=-x_1+x_2 = 0$')
plt.plot(x, y2, label='$d_{12}(x)=x_1+x_2-5 = 0$')
plt.plot(x1, y, label='$d_{13}(x)=-x_1+2 = 0$')

plt.axhline(y=0, color='black', linestyle='--', linewidth=0.8)
plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.axvline(x=0, color='black', linestyle='--', linewidth=0.8)

plt.fill_between(x, np.maximum(y1, y2), 10 * np.ones_like(x), color='blue', alpha=0.2)
plt.fill_between(x, y2, -5 * np.ones_like(x), where=(x <= 2), color='orange', alpha=0.2)
plt.fill_between(x, y1, -5 * np.ones_like(x), where=(x > 1.9), color='green', alpha=0.2)

plt.annotate('$\\omega_1$\n$d_{12}(x)>0$\n$d_{13}(x)>0$',
             xy=(-3, 2), xytext=(-3, 2), fontsize=12, color='black')
plt.annotate('$\\omega_2$\n$d_{32}(x)>0$\n$d_{12}(x)>0$',
             xy=(1, 6), xytext=(1, 5), fontsize=12, color='black')
plt.annotate('$\\omega_3$\n$d_{32}(x)>0$\n$d_{13}(x)>0$',
             xy=(6, 1), xytext=(6, 1), fontsize=12, color='black')

plt.xlim(-5, 10)
plt.ylim(-5, 10)
plt.grid(True)

plt.legend(loc='lower right')

plt.show()

最后更新于