跳转至

21 继承和多态:完善面向对象编程的两大特征

你好,我是海纳。

封装、继承和多态是面向对象编程的三大特征。从第 18 课开始,我们专注于构建类的定义和对象初始化能力,也就是说,只是完成了封装这一特征,这一节课,我们就会进一步实现方法重载和复写,以及类的继承特征。

第 20 课中我们介绍了操作符的重载,Python 除了可以对操作符进行重载之外,还可以对各种特殊方法进行重载。接下来,我们实现方法的重载。

内建方法重载

Python 里有很多内建方法,比如 len 方法、pow 方法等。len 方法可以支持字符串、列表、字典等类型。如果想让 len 方法也支持自建类型,就必须为自定义类型添加 __len__ 方法的实现,例如:

class Vector(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __add__(self, v):
        return Vector(self.x + v.x, self.y + v.y)

    def __len__(self):
        return self.x * self.x + self.y * self.y

print(len(Vector(3, 4)))
print(len("hello"))

为任意对象增加 len 方法,步骤和增加 add 方法是一样的。

第一步,在 HiObject 类中增加 len 方法,并实现为调用对应 _klass 的 len 方法。

// vm/runtime/interpreter.cpp
_builtins->put(new HiString("len"),      new FunctionObject(object_len));

// vm/runtime/functionObject.cpp
HiObject* object_len(HiList* args, HiDict* kwargs) {
    return args->get(0)->len();
}

// vm/object/hiObject.cpp
HiObject* HiObject::len() {
    return klass()->len(this);
}

第二步,在 Klass 类中尝试调用 __len__ 方法。下面的代码直接使用上节课所定义的 find_and_call 方法就可以很轻松地实现这个功能。

HiObject* Klass::len(HiObject* obj) {
    return find_and_call(obj, nullptr, ST(len));
}

第三步,同步修改内建的字符串类、字典类、列表类的 len 方法。这里只列出了字符串类的实现。字典和列表也实现得也差不多,这里就不再赘述了。

HiObject* StringKlass::len(HiObject* obj) {
    return new HiInteger(obj->as<HiString>()->length());
}

经过了以上修改,前面的测试用例就可以顺利执行了。测试用例的最后,可以正确地打印二维向量的长度。

除了 len 之外,我还列出了其他几个方法,这些方法都会自动对应自定义类型里的特定方法,你可以看一下。

图片

这个表里的 str 函数和 repr 函数的作用很接近。实际上,大多数自定义类,这两个函数的实现也是相同的。但这里也请你注意它们之间的细微区别。

首先,使用 print 函数打印一个对象时,虚拟机默认调用它的 str 方法,如果对象的类型里没有定义 str 方法,就继续调用 repr 方法。如果还没有定义,则打印对象的内存地址。

其次,str 方法更倾向于输出面向用户的信息,可读性尽量高。repr 方法则更倾向于输出面向开发者的信息,要尽可能地打印对象内部信息帮助开发者调试。

所以,我们可以使用对象的 str 方法和 repr 方法来修改 print 方法的逻辑,你可以看一下修改之后的代码。

void Klass::print(HiObject* x) {
    HiObject* func = x->getattr(ST(str));
    if (func != Universe::HiNone) {
        HiObject* s = Interpreter::get_instance()->call_virtual(func, nullptr);
        s->as<HiString>()->print();
        return;
    }

    func = x->getattr(ST(repr));
    if (func != Universe::HiNone) {
        HiObject* r = Interpreter::get_instance()->call_virtual(func, nullptr);
        r->as<HiString>()->print();
        return;
    }

    printf("<object at %p>", x);
}

除了 str 和 repr 之外,表里的大多数函数的实现都很简单,你可以自己实现。

支持特殊操作符

可以重载的操作符,除了上节课所介绍的一元、二元操作符之外,还有一些特殊操作符。比如函数调用、取下标等操作,接下来,我们就继续重载这些特殊的操作符。

函数调用操作符

在 C++ 中,开发者可以重载 () 操作符,这样,就可以把对象当做函数一样调用了。STL 中的 Functor 正是依赖函数调用操作符重载才能够实现。

在 Python 中,开发者也可以通过在类中定义 __call__ 方法来实现 () 操作符的重载,例如:

class Functor(object):
    def __init__(self):
        pass

    def __call__(self):
        print("calling.")

a = Functor()
a()

在上述测试用例中,a 是一个普通对象,但是却可以像函数一样被调用。为了实现这个功能,我们需要实现 Klass 类的 call 方法。

HiObject* Klass::call(HiObject* x, HiList* args, HiDict* kwargs) {
    HiObject* callable = x->getattr(ST(call));

    if (callable == Universe::HiNone) {
        x->print();
        printf(" is non-callable\n");
        assert(false);
    }

    return Interpreter::get_instance()->call_virtual(callable, args);
}

在原有的基础上,添加 call 方法是非常容易的,只需要在目标对象上查找 __call__ 方法并调用它就可以了。

下面我们继续研究取下标操作符。

取下标和取属性操作符

取下标操作符是 [],Python 中也支持对这个操作符进行重载。另外,Python 也支持重载点号操作符,点号在 Python 中用于取属性。

在 Python 中,可以通过在类里定义 __getitem__ 方法来实现 [] 操作符的重载,定义 __getattr__ 方法来实现取属性操作符的重载。我们可以看一个例子。

class A(object):
    def __init__(self, *args):
        self.attrs = dict()
        n = len(args)
        i = 0
        while (i < n):
            a = args[i]
            i += 1
            b = args[i]
            i += 1
            self.attrs[a] = b
            
    def __getitem__(self, key):
        if key in self.attrs:
            return self.attrs[key]
        else:
            return "Error"

    def __setitem__(self, key, value):
        self.attrs[key] = value

    def __delitem__(self, key):
        del self.attrs[key]

    def __str__(self):
        return "object of A"

a = A("hello", "hi", "how are you", "fine")
print(a)
print(a["hello"])
print(a["how are you"])
a["one"] = 1
print(a["one"])
del a["one"]
print(a["one"])
print(a.attrs)

中括号所对应的取下标操作会被翻译成 BINARY_SUBSCR 字节码,这个字节码的实现是调用 HiObject 的 subscr 方法。所以,我们只需要在 Klass 中实现这个方法就可以了。

HiObject* Klass::subscr(HiObject* x, HiObject* y) {
    HiList* args = new HiList();
    args->append(y);
    return find_and_call(x, args, ST(getitem));
}

相应的,如果要对下标进行赋值,所对应的字节码是 STORE_SUBSCR。重载下标赋值操作,需要在类定义中添加 __setitem__ 方法,而删除下标操作,则需要添加 __delitem__ 方法,我们也把这两个方法添加到 Klass 里。

void Klass::store_subscr(HiObject* x, HiObject* y, HiObject* z) {
    HiList* args = new HiList();
    args->append(y);
    args->append(z);
    find_and_call(x, args, ST(setitem));
}

void Klass::del_subscr(HiObject* x, HiObject* y) {
    HiList* args = new HiList();
    args->append(y);
    find_and_call(x, args, ST(delitem));
}

取下标的机制虽然特殊,但究其根源,与数值计算并没有什么区别。接下来我们继续研究取属性操作。不管是自定义类里的 __getattr__ 方法,还是在虚拟机中实现 getattr 都必须十分小心,要避免进入无穷递归中。我们先看一个测试用例。

class B(object):
    def __init__(self):
        self.keys = {}
  
    def __getattr__(self, k): 
        if k in self.keys:
            return self.keys[k]
        else:
            return None
b = B()
print(b.value)

测试代码的最后一行是一个取对象属性的操作,前边我们介绍了,这个操作可以被 __getattr__ 方法重载。因此执行这一行的时候,虚拟机就会转而调用 b 的 __getattr__ 方法。

在这个方法里,又使用了一次取对象属性的操作,查找 self 对象的 keys 属性,这又会产生一次对 __getattr__ 方法的调用。无穷递归调用就产生了,所以虚拟机很快就会崩溃退出。

实际上,不管是取下标操作,还是普通的数值计算,都可能写出无穷递归,只是那种情况相对少一些。而取对象属性的操作太常见,所以写出上面错误代码的可能性就更大一些。

下面我们再看一个正确的示例。

keys = []
values = []

class B(object):
    def __setattr__(self, k, v):
        if k in keys:
            index = keys.index(k)
            values[index] = v
        else:
            keys.append(k)
            values.append(v)

    def __getattr__(self, k):
        if k in keys:
            index = keys.index(k)
            return values[index]
        else:
            return None

b = B()
b.foo = 1
b.bar = 2
print(b.foo)
print(b.bar)
b.foo = 3
print(b.foo)

设置对象属性会被翻译成 STORE_ATTR(第 24 和 25 行),所以我们就在 Klass 中增加调用 __setattr__ 的逻辑。Klass 中已经有 setattr 方法了,只需要在这个方法里增加重载的逻辑即可。

// setattr for normal object.
HiObject* Klass::setattr(HiObject* obj, HiObject* x, HiObject* y) {
    HiObject* func = obj->klass()->klass_dict()->get(ST(setattr));

    // 如果未定义__setattr__方法,就直接放到对象的obj_dict中
    if (func == Universe::HiNone) {
        obj->obj_dict()->put(x, y);
        return Universe::HiNone;
    }

    func = new MethodObject(func->as<FunctionObject>(), obj);
    HiList* args = new HiList();
    args->append(x);
    args->append(y);
    return Interpreter::get_instance()->call_virtual(func, args);
}

上述代码中,setattr 一开始就是判断 x 的类定义中是否有 __setattr__ 方法。如果没有的话,就沿着原来的逻辑继续去对象的属性字典里查找。如果有这个方法,就通过 call_virtual 执行这个方法。

在测试用例里取对象上的 foo 属性(第 26 行),会被翻译为 LOAD_ATTR。我们需要在 Klass 的 getattr 方法中增加重载取属性操作的逻辑。

HiObject* Klass::getattr(HiObject* x, HiObject* y) {
    HiObject* func = x->klass()->klass_dict()->get(ST(getattr));

    // 如果类里定义了__getattr__方法,就先调用这个方法
    if (func != Universe::HiNone) {
        func = new MethodObject(func->as<FunctionObject>(), x);
        HiList* args = new HiList();
        args->append(y);
        return Interpreter::get_instance()->call_virtual(func, args);
    }

    // 如果没有定义 __getattr__方法,优化去对象字典里找
    if (x->obj_dict()->has_key(y)) {
        return x->obj_dict()->get(y);
    }

    // 如果对象字典里也没有,就去类里找
    HiObject* result = _klass_dict->get(y);
    // Only klass attribute needs bind.
    if (!MethodObject::is_method(result) &&
        MethodObject::is_function(result)) {
        result = new MethodObject(result->as<FunctionObject>(), x);
    }

    return result;
}

这段代码结合注释应该不难理解。到这里,取下标操作和取属性操作就都能支持了,编译执行测试用例,就可以打印出正确的值。

重载操作符还有很多知识,比如原位计算、右结合运算等。但根本性的问题都已经介绍完了,更多的方法重载你可以自己实现。

实现类的继承

这部分我们来实现类的继承,先看第一个例子。

class A(object):
    def say(self):
        print("I am A")

class B(A):
    def say(self):
        print("I am B")

class C(A):
    pass

b = B()
c = C()

b.say()    # "I am B"
c.say()    # "I am A"

代码中对象 c 的类型是 C,C 中没有定义 say 方法,虚拟机就会从它的超类中查找,发现 A 中有定义 say 方法,所以第 16 行的运行结果是打印 "I am A"

而对象 b 的类型是 B,B 中定义了自己的 say 方法,虚拟机就会直接执行这个方法,所以第 15 行的运行结果是打印 "I am B"

根据这个例子我们可以得到结论,Python 在调用某个对象的方法时,会先检查它所对应的类型中,是否定义了该方法。如果没有定义,就转而去查找它的父类,如果父类还是没有,就要继续查找它的祖先类,直到找到方法定义为止。

如果 Python 像 Java 语言那样只支持单继承,每个类只能有一个父类,那这个查找方法定义的过程也会非常简单。但实际上,Python 是支持多继承的。如果一个类有多个父类的时候,要想从父类中查找方法定义,应该按照怎样的顺序查找呢?

Python 在父类中查找方法的过程也被叫做方法解析的过程,这个查询父类的顺序就被叫做方法解析顺序(Method Resolution Order,MRO)。

Python 发展到今天经历了 3 种 MRO 算法,分别是:

  1. 从左往右,采用深度优先搜索(DFS)的算法,叫做旧式类的 MRO。
  2. 从 Python 2.2 版本开始,新式类在采用深度优先搜索算法的基础上做了优化,解决菱形继承问题。
  3. Python 2.3 版本对新式类采用了 C3 算法。由于 Python 3.x 仅支持新式类,所以这个版本只使用 C3 算法。

Python 2.7 中所使用的基于深度优先搜索的算法,在我所写的《自己动手写Python虚拟机》中有论述,如果你感兴趣的话,可以去读一读,我们这个课程只关注 C3 算法。

C3 算法

如果 B 类型有 class B(A):pass 的继承关系,就记 B 的mro序列为 [B,A]

如果 B 类型继承自多个基类 class B(A1,A2,...,An): pass,就记 B 的 mro 序列为:

mro(B) = [B] + merge(mro(A1), mro(A2),...,mro(An))

其中,merge 操作用来合并 B 所有父类的 mro 序列。我们来看一下它的算法步骤。

  1. 遍历所有父类的 mro 序列,如果一个序列的第一个元素,在其他序列中也是第一个元素,或不在其他序列出现,那么这个元素就是一个合法的元素。
  2. 如果找不到合法元素,说明类定义不合法,算法结束。
  3. 从所有 mro 序列中删除第一步找到的合法元素,并把它合并到 B 类型的 mro 里。
  4. 所有的非空 mro 序列重新组成一个序列,重复第一步的算法。直到所有 mro 序列都变成空序列。

你可以看一下我给出的几个定义。

class A(object):pass
class B(object):pass
class C(object):pass
class E(A,B):pass
class F(B,C):pass
class G(E,F):pass

A、B、C都继承自一个基类,所以 mro 序列依次为 [A, O]、[B, O]、[C, O]。计算 E 的 mro 序列的步骤是:

mro(E) = [E] + merge(mro(A), mro(B), [A, B])
       = [E] + merge([A, O], [B, O], [A, B]) (1)
       = [E, A] + merge([O], [B,O], [B])     (2)
       = [E, A, B] + merge([O], [O])         (3)
       = [E, A, B, O]

公式(1)处,A是序列 [A, O] 中的第一个元素,在序列 [B, O] 中不出现,在序列 [A, B] 中也是第一个元素,所以从全体序列 ([A, O]、[B, O]、[A, B]) 中删除A,并把 A 合并到当前 mro 中,得到 [E, A]。

公式(2)处,O 是序列 [O] 中的第一个元素,但 O 在序列 [B, O] 中出现并且不是第一个元素,所以 O 不满足条件。继续查看 [B, O] 的第一个元素 B,B 满足条件,所以从全部序列中删除B,并将 B 合并到序列 [E, A] 中,得到 [E, A, B]。

我们再看一个例子。

class X(object): pass
class Y(object): pass
class A(X,Y): pass
class B(Y,X): pass
class C(A, B): pass

给这些类的定义套用 C3 算法,我们看得到的结果。

mro(X) = [X, O]
mro(Y) = [Y, O]
mro(A) = [A] + merge(mro(X), mro(Y))
       = [A] + [X, O] + [Y, O]
       = [A, X] + [O] + [Y, O]
       = [A, X, Y] + [O] + [O]
       = [A, X, Y, O]
mro(B) = [B, Y, X, O]

这个演算过程比较简单,我就不再解释了。接下来,我们重点关注 C 的 mro 序列的计算过程。

mro(C) = [C] + merge(mro(A), mro(B))
       = [C] + merge([A, X, Y, O], [B, Y, X, O])
       = [C, A] + merge([X, Y, O], [B, Y, X, O])
       = [C, A, B] + merge([X, Y, O], [Y, X, O])

演算到这一步的时候,就发生问题了,不管是第一个序列的首元素 X 还是第二个序列的首元素 Y,都不满足条件,所以算法无法使 merge 的各个序列变为空序列,只能报错退出。

由此可见,使用 C3 算法既可以很好地解决菱形继承的问题,也可以检查出继承顺序不一致的情况。接下来,我们就来实现这个算法。

实现 C3 算法

为了实现多继承,我们可以在 Klass 中定义 super 属性,它是一个列表,记录了某个类型的所有父类。同时为每个类型再创建一个属性,记录该类型的方法解析顺序,我们把这个属性起名为 _mro。

// [object/klass.hpp]
class Klass {
private:
    HiList*       _super;
    HiList*       _mro;
    ...
public:
    ...
    void add_super(Klass* x);
    void order_supers();

    HiList* super()                       { return _super; }
    void set_super_list(HiList* x)        { _super = x; }
    HiList* mro()                         { return _mro; }
    ...
};

// [object/klass.cpp]
void Klass::add_super(Klass* klass) {
    if (_super == nullptr)
        _super = new HiList();

    _super->append(klass->type_object());
}

为了实现继承,我们在 Klass 里添加了 5 个方法,除了 order_super 之外的其他方法的实现,逻辑都非常简单,在代码中已经列出来了,你可以看一下。

order_super 方法是 C3 算法的具体实现,用来计算类型的 mro,我们来重点看一下它的代码。

HiList* Klass::linear(HiTypeObject* obj) {
    HiList* result = new HiList();
    HiList* mro = obj->mro();
    for (int i = 0; i < mro->length(); i++) {
        result->append(mro->get(i));
    }

    return result;
}

HiList* Klass::merge(HiList* supers) {
    if (supers->empty()) {
        return new HiList();
    }

    for (int i = 0; i < supers->length(); i++) {
        bool valid = true;
        // head = supers[i][0]
        HiTypeObject* head = supers->get(i)->as<HiList>()->get(0)->as<HiTypeObject>();
        for (int j = 0; j < supers->length(); j++) {
            if (j == i) continue;
            // if head in supers[j][1:]
            if (supers->get(j)->as<HiList>()->index(head) > 0) {
                valid = false;
                break;
            }
        }

        if (!valid) {
            continue;
        }

        HiList* next = new HiList();
        for (int j = 0; j < supers->length(); j++) {
            HiList* item = supers->get(j)->as<HiList>();
            item->remove(head);
            if (!item->empty()) {
                next->append(item);
            }
        }
        HiList* result = merge(next);
        result->insert(0, head);

        return result;
    }

    printf("Should not reach here\n");
    assert(false);
    return nullptr;
}

void Klass::order_supers() {
    if (_super == nullptr) {
        _mro = new HiList();
        _mro->append(_type_object);
        return;
    }

    HiList* all = new HiList();
    for (int i = 0; i < _super->length(); i++) {
        all->append(linear(_super->get(i)->as<HiTypeObject>()));
    }

    _mro = merge(all);
    _mro->insert(0, _type_object);
    _klass_dict->put(ST(mro), _mro);
}

order_supers 的作用是将本类型的所有父类的 mro 序列进行合并。变量 all 里记录了父类的 mro 序列(第 59 至 63 行)。

合并的具体动作是由 merge 方法完成的。merge 通过一个大循环来找到合法元素(第 16 行至 32 行)。如果所有序列都遍历完了,还没有找到合法元素,就会报错退出(第 47 至 49 行)。

如果找到合法元素,就把它从全部序列中删除,将剩下的序列再重新组成一个列表,再次调用 merge 方法进行合并(第 33 至 41 行)。同时,合法元素会被合并至结果列表中(第 42 行)。

整个算法全部结束以后,再把 mro 序列和 __mro__这个字符串关联起来。这样,在程序中就可以通过类型的 __mro__ 属性访问到这个列表了。

最后,我们要在每个类的初始化里完成 mro 的排序工作,以 HiList 类为例,你可以看一下它的代码。

void ListKlass::initialize() {
    HiDict * klass_dict = new HiDict();
    //...
    set_klass_dict(klass_dict);
    (new HiTypeObject())->set_own_klass(this);
    set_name(new HiString("list"));

    add_super(ObjectKlass::get_instance());
    order_supers();
}

而对于用户自定义类型,则应该在 klass 创建的时候构建它的 mro,也就是 create_klass 中,你可以自己添加。

最后,类的多继承对于方法的查找是有影响的,主要是在 getattr 中,如果在本类型中查找不到,还要去自己的 mro 列表中的类型里去查找,所以 getattr 就变成了这个样子。

HiObject* Klass::find_in_mro(HiObject* obj, HiString* name) {
    HiObject* x = obj->klass()->klass_dict()->get(name);

    if (x != Universe::HiNone) {
        return x;
    }

    for (int i = 0; i < obj->klass()->mro()->length(); i++) {
        x = obj->klass()->mro()->get(i)->getattr(name);
        if (x != Universe::HiNone) {
            return x;
        }
    }

    return x;
}

HiObject* Klass::getattr(HiObject* x, HiObject* y) {
    HiObject* func = find_in_mro(x, ST(getattr));

    // 如果类里定义了__getattr__方法,就先调用这个方法
    if (func != Universe::HiNone) {
        func = new MethodObject(func->as<FunctionObject>(), x);
        HiList* args = new HiList();
        args->append(y);
        return Interpreter::get_instance()->call_virtual(func, args);
    }

    // 如果没有定义 __getattr__方法,就去对象字典里找
    if (x->obj_dict()->has_key(y)) {
        return x->obj_dict()->get(y);
    }

    // 如果对象字典里也没有,就去类里找
    HiObject* result = find_in_mro(x, y->as<HiString>());
    // Only klass attribute needs bind.
    if (MethodObject::is_function(result) ||
        MethodObject::is_native(result)) {
        result = new MethodObject(result->as<FunctionObject>(), x);
    }

    return result;
}

经过不断地增加功能,我们最终将 getattr 的功能全部补齐了。

首先,虚拟机在对象所对应的类型中查找是否有 __getattr__ 方法(第 19 行),如果找到了这个方法就直接调用这个方法。如果没有找到,就转而在对象字典里查找是否有 y 属性(第 30 至 32 行)。如果失败,再去对象所对应的类型中查找是否有 y 属性。这个逻辑就比较清晰了。

经过了这些改造以后,我们虚拟机的行为终于和 CPython 3.8 的行为一致了。开始时的例子也能正确执行了。

到此为止,我们的虚拟机中的对象系统就构建得差不多了,这为虚拟机的下一步开发奠定了坚实的基础,还有一些与对象系统相关的功能,比如异常的定义和处理等,我们等模块功能实现了以后再来添加。

总结

这节课我们先实现了内建方法重载,主要以 len、str 和 repr 三个函数为例,说明了如何实现内建方法重载。然后我们介绍了三个比较特殊的操作符。

  1. 第一个是 () 操作符,用于模拟函数调用的,如果一个对象定义了 __call__ 方法,那么这个对象作为 CALL_FUNCTION 指令的操作数时,它的 __call__ 方法就会被执行。
  2. 第二个是下标操作符 [],分别使用 __setitem____getitem____delitem__ 来支持设置、读取和删除操作。
  3. 第三个是属性操作符点号,我们通过 __getattr____setattr__ 方法来说明了如何进行对象属性的操作。
    第三部分我们实现了 Python 的多继承,以及如何确定继承顺序。这部分,我们重点介绍了 C3 算法的原理和实现。

通过这四节课的努力,我们终于将对象系统构建得基本完备了。到目前为止,我们在写代码的时候,需要什么对象都是直接通过 new 操作符来创建的,从来没有释放过,这必然会带来大量的内存泄漏。从下一节课起,我们该把内存好好梳理一遍了。

思考题

多继承机制会如何影响 type、isinstance 等几个和类型相关的函数?应该如何改写才能使它们正确工作呢?欢迎你把你思考后的结果分享到评论区,也欢迎你把这节课的内容分享给其他朋友,我们下节课再见!

精选留言(2)
  • 冯某 👍(0) 💬(0)

    记录一下

    2024-12-09

  • ifelse 👍(0) 💬(0)

    学习打卡

    2024-11-05