Closes #16551. Cleanup pickle.py.
This commit is contained in:
parent
c8fb047d69
commit
a3e32c92cf
223
Lib/pickle.py
223
Lib/pickle.py
@ -26,9 +26,10 @@ Misc variables:
|
||||
from types import FunctionType, BuiltinFunctionType
|
||||
from copyreg import dispatch_table
|
||||
from copyreg import _extension_registry, _inverted_registry, _extension_cache
|
||||
import marshal
|
||||
from itertools import islice
|
||||
import sys
|
||||
import struct
|
||||
from sys import maxsize
|
||||
from struct import pack, unpack
|
||||
import re
|
||||
import io
|
||||
import codecs
|
||||
@ -58,11 +59,6 @@ HIGHEST_PROTOCOL = 3
|
||||
# there are too many issues with that.
|
||||
DEFAULT_PROTOCOL = 3
|
||||
|
||||
# Why use struct.pack() for pickling but marshal.loads() for
|
||||
# unpickling? struct.pack() is 40% faster than marshal.dumps(), but
|
||||
# marshal.loads() is twice as fast as struct.unpack()!
|
||||
mloads = marshal.loads
|
||||
|
||||
class PickleError(Exception):
|
||||
"""A common base class for the other pickling exceptions."""
|
||||
pass
|
||||
@ -231,7 +227,7 @@ class _Pickler:
|
||||
raise PicklingError("Pickler.__init__() was not called by "
|
||||
"%s.__init__()" % (self.__class__.__name__,))
|
||||
if self.proto >= 2:
|
||||
self.write(PROTO + bytes([self.proto]))
|
||||
self.write(PROTO + pack("<B", self.proto))
|
||||
self.save(obj)
|
||||
self.write(STOP)
|
||||
|
||||
@ -258,20 +254,20 @@ class _Pickler:
|
||||
self.memo[id(obj)] = memo_len, obj
|
||||
|
||||
# Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i.
|
||||
def put(self, i, pack=struct.pack):
|
||||
def put(self, i):
|
||||
if self.bin:
|
||||
if i < 256:
|
||||
return BINPUT + bytes([i])
|
||||
return BINPUT + pack("<B", i)
|
||||
else:
|
||||
return LONG_BINPUT + pack("<I", i)
|
||||
|
||||
return PUT + repr(i).encode("ascii") + b'\n'
|
||||
|
||||
# Return a GET (BINGET, LONG_BINGET) opcode string, with argument i.
|
||||
def get(self, i, pack=struct.pack):
|
||||
def get(self, i):
|
||||
if self.bin:
|
||||
if i < 256:
|
||||
return BINGET + bytes([i])
|
||||
return BINGET + pack("<B", i)
|
||||
else:
|
||||
return LONG_BINGET + pack("<I", i)
|
||||
|
||||
@ -286,20 +282,20 @@ class _Pickler:
|
||||
|
||||
# Check the memo
|
||||
x = self.memo.get(id(obj))
|
||||
if x:
|
||||
if x is not None:
|
||||
self.write(self.get(x[0]))
|
||||
return
|
||||
|
||||
# Check the type dispatch table
|
||||
t = type(obj)
|
||||
f = self.dispatch.get(t)
|
||||
if f:
|
||||
if f is not None:
|
||||
f(self, obj) # Call unbound method with explicit self
|
||||
return
|
||||
|
||||
# Check private dispatch table if any, or else copyreg.dispatch_table
|
||||
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
|
||||
if reduce:
|
||||
if reduce is not None:
|
||||
rv = reduce(obj)
|
||||
else:
|
||||
# Check for a class with a custom metaclass; treat as regular class
|
||||
@ -313,11 +309,11 @@ class _Pickler:
|
||||
|
||||
# Check for a __reduce_ex__ method, fall back to __reduce__
|
||||
reduce = getattr(obj, "__reduce_ex__", None)
|
||||
if reduce:
|
||||
if reduce is not None:
|
||||
rv = reduce(self.proto)
|
||||
else:
|
||||
reduce = getattr(obj, "__reduce__", None)
|
||||
if reduce:
|
||||
if reduce is not None:
|
||||
rv = reduce()
|
||||
else:
|
||||
raise PicklingError("Can't pickle %r object: %r" %
|
||||
@ -448,12 +444,12 @@ class _Pickler:
|
||||
|
||||
def save_bool(self, obj):
|
||||
if self.proto >= 2:
|
||||
self.write(obj and NEWTRUE or NEWFALSE)
|
||||
self.write(NEWTRUE if obj else NEWFALSE)
|
||||
else:
|
||||
self.write(obj and TRUE or FALSE)
|
||||
self.write(TRUE if obj else FALSE)
|
||||
dispatch[bool] = save_bool
|
||||
|
||||
def save_long(self, obj, pack=struct.pack):
|
||||
def save_long(self, obj):
|
||||
if self.bin:
|
||||
# If the int is small enough to fit in a signed 4-byte 2's-comp
|
||||
# format, we can store it more efficiently than the general
|
||||
@ -461,39 +457,36 @@ class _Pickler:
|
||||
# First one- and two-byte unsigned ints:
|
||||
if obj >= 0:
|
||||
if obj <= 0xff:
|
||||
self.write(BININT1 + bytes([obj]))
|
||||
self.write(BININT1 + pack("<B", obj))
|
||||
return
|
||||
if obj <= 0xffff:
|
||||
self.write(BININT2 + bytes([obj&0xff, obj>>8]))
|
||||
self.write(BININT2 + pack("<H", obj))
|
||||
return
|
||||
# Next check for 4-byte signed ints:
|
||||
high_bits = obj >> 31 # note that Python shift sign-extends
|
||||
if high_bits == 0 or high_bits == -1:
|
||||
# All high bits are copies of bit 2**31, so the value
|
||||
# fits in a 4-byte signed int.
|
||||
if -0x80000000 <= obj <= 0x7fffffff:
|
||||
self.write(BININT + pack("<i", obj))
|
||||
return
|
||||
if self.proto >= 2:
|
||||
encoded = encode_long(obj)
|
||||
n = len(encoded)
|
||||
if n < 256:
|
||||
self.write(LONG1 + bytes([n]) + encoded)
|
||||
self.write(LONG1 + pack("<B", n) + encoded)
|
||||
else:
|
||||
self.write(LONG4 + pack("<i", n) + encoded)
|
||||
return
|
||||
self.write(LONG + repr(obj).encode("ascii") + b'L\n')
|
||||
dispatch[int] = save_long
|
||||
|
||||
def save_float(self, obj, pack=struct.pack):
|
||||
def save_float(self, obj):
|
||||
if self.bin:
|
||||
self.write(BINFLOAT + pack('>d', obj))
|
||||
else:
|
||||
self.write(FLOAT + repr(obj).encode("ascii") + b'\n')
|
||||
dispatch[float] = save_float
|
||||
|
||||
def save_bytes(self, obj, pack=struct.pack):
|
||||
def save_bytes(self, obj):
|
||||
if self.proto < 3:
|
||||
if len(obj) == 0:
|
||||
if not obj: # bytes object is empty
|
||||
self.save_reduce(bytes, (), obj=obj)
|
||||
else:
|
||||
self.save_reduce(codecs.encode,
|
||||
@ -501,13 +494,13 @@ class _Pickler:
|
||||
return
|
||||
n = len(obj)
|
||||
if n < 256:
|
||||
self.write(SHORT_BINBYTES + bytes([n]) + bytes(obj))
|
||||
self.write(SHORT_BINBYTES + pack("<B", n) + obj)
|
||||
else:
|
||||
self.write(BINBYTES + pack("<I", n) + bytes(obj))
|
||||
self.write(BINBYTES + pack("<I", n) + obj)
|
||||
self.memoize(obj)
|
||||
dispatch[bytes] = save_bytes
|
||||
|
||||
def save_str(self, obj, pack=struct.pack):
|
||||
def save_str(self, obj):
|
||||
if self.bin:
|
||||
encoded = obj.encode('utf-8', 'surrogatepass')
|
||||
n = len(encoded)
|
||||
@ -515,39 +508,36 @@ class _Pickler:
|
||||
else:
|
||||
obj = obj.replace("\\", "\\u005c")
|
||||
obj = obj.replace("\n", "\\u000a")
|
||||
self.write(UNICODE + bytes(obj.encode('raw-unicode-escape')) +
|
||||
b'\n')
|
||||
self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n')
|
||||
self.memoize(obj)
|
||||
dispatch[str] = save_str
|
||||
|
||||
def save_tuple(self, obj):
|
||||
write = self.write
|
||||
proto = self.proto
|
||||
|
||||
n = len(obj)
|
||||
if n == 0:
|
||||
if proto:
|
||||
write(EMPTY_TUPLE)
|
||||
if not obj: # tuple is empty
|
||||
if self.bin:
|
||||
self.write(EMPTY_TUPLE)
|
||||
else:
|
||||
write(MARK + TUPLE)
|
||||
self.write(MARK + TUPLE)
|
||||
return
|
||||
|
||||
n = len(obj)
|
||||
save = self.save
|
||||
memo = self.memo
|
||||
if n <= 3 and proto >= 2:
|
||||
if n <= 3 and self.proto >= 2:
|
||||
for element in obj:
|
||||
save(element)
|
||||
# Subtle. Same as in the big comment below.
|
||||
if id(obj) in memo:
|
||||
get = self.get(memo[id(obj)][0])
|
||||
write(POP * n + get)
|
||||
self.write(POP * n + get)
|
||||
else:
|
||||
write(_tuplesize2code[n])
|
||||
self.write(_tuplesize2code[n])
|
||||
self.memoize(obj)
|
||||
return
|
||||
|
||||
# proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple
|
||||
# has more than 3 elements.
|
||||
write = self.write
|
||||
write(MARK)
|
||||
for element in obj:
|
||||
save(element)
|
||||
@ -561,25 +551,23 @@ class _Pickler:
|
||||
# could have been done in the "for element" loop instead, but
|
||||
# recursive tuples are a rare thing.
|
||||
get = self.get(memo[id(obj)][0])
|
||||
if proto:
|
||||
if self.bin:
|
||||
write(POP_MARK + get)
|
||||
else: # proto 0 -- POP_MARK not available
|
||||
write(POP * (n+1) + get)
|
||||
return
|
||||
|
||||
# No recursion.
|
||||
self.write(TUPLE)
|
||||
write(TUPLE)
|
||||
self.memoize(obj)
|
||||
|
||||
dispatch[tuple] = save_tuple
|
||||
|
||||
def save_list(self, obj):
|
||||
write = self.write
|
||||
|
||||
if self.bin:
|
||||
write(EMPTY_LIST)
|
||||
self.write(EMPTY_LIST)
|
||||
else: # proto 0 -- can't use EMPTY_LIST
|
||||
write(MARK + LIST)
|
||||
self.write(MARK + LIST)
|
||||
|
||||
self.memoize(obj)
|
||||
self._batch_appends(obj)
|
||||
@ -599,17 +587,9 @@ class _Pickler:
|
||||
write(APPEND)
|
||||
return
|
||||
|
||||
items = iter(items)
|
||||
r = range(self._BATCHSIZE)
|
||||
while items is not None:
|
||||
tmp = []
|
||||
for i in r:
|
||||
try:
|
||||
x = next(items)
|
||||
tmp.append(x)
|
||||
except StopIteration:
|
||||
items = None
|
||||
break
|
||||
it = iter(items)
|
||||
while True:
|
||||
tmp = list(islice(it, self._BATCHSIZE))
|
||||
n = len(tmp)
|
||||
if n > 1:
|
||||
write(MARK)
|
||||
@ -620,14 +600,14 @@ class _Pickler:
|
||||
save(tmp[0])
|
||||
write(APPEND)
|
||||
# else tmp is empty, and we're done
|
||||
if n < self._BATCHSIZE:
|
||||
return
|
||||
|
||||
def save_dict(self, obj):
|
||||
write = self.write
|
||||
|
||||
if self.bin:
|
||||
write(EMPTY_DICT)
|
||||
self.write(EMPTY_DICT)
|
||||
else: # proto 0 -- can't use EMPTY_DICT
|
||||
write(MARK + DICT)
|
||||
self.write(MARK + DICT)
|
||||
|
||||
self.memoize(obj)
|
||||
self._batch_setitems(obj.items())
|
||||
@ -648,16 +628,9 @@ class _Pickler:
|
||||
write(SETITEM)
|
||||
return
|
||||
|
||||
items = iter(items)
|
||||
r = range(self._BATCHSIZE)
|
||||
while items is not None:
|
||||
tmp = []
|
||||
for i in r:
|
||||
try:
|
||||
tmp.append(next(items))
|
||||
except StopIteration:
|
||||
items = None
|
||||
break
|
||||
it = iter(items)
|
||||
while True:
|
||||
tmp = list(islice(it, self._BATCHSIZE))
|
||||
n = len(tmp)
|
||||
if n > 1:
|
||||
write(MARK)
|
||||
@ -671,8 +644,10 @@ class _Pickler:
|
||||
save(v)
|
||||
write(SETITEM)
|
||||
# else tmp is empty, and we're done
|
||||
if n < self._BATCHSIZE:
|
||||
return
|
||||
|
||||
def save_global(self, obj, name=None, pack=struct.pack):
|
||||
def save_global(self, obj, name=None):
|
||||
write = self.write
|
||||
memo = self.memo
|
||||
|
||||
@ -702,9 +677,9 @@ class _Pickler:
|
||||
if code:
|
||||
assert code > 0
|
||||
if code <= 0xff:
|
||||
write(EXT1 + bytes([code]))
|
||||
write(EXT1 + pack("<B", code))
|
||||
elif code <= 0xffff:
|
||||
write(EXT2 + bytes([code&0xff, code>>8]))
|
||||
write(EXT2 + pack("<H", code))
|
||||
else:
|
||||
write(EXT4 + pack("<i", code))
|
||||
return
|
||||
@ -732,25 +707,6 @@ class _Pickler:
|
||||
dispatch[BuiltinFunctionType] = save_global
|
||||
dispatch[type] = save_global
|
||||
|
||||
# Pickling helpers
|
||||
|
||||
def _keep_alive(x, memo):
|
||||
"""Keeps a reference to the object x in the memo.
|
||||
|
||||
Because we remember objects by their id, we have
|
||||
to assure that possibly temporary objects are kept
|
||||
alive by referencing them.
|
||||
We store a reference at the id of the memo, which should
|
||||
normally not be used unless someone tries to deepcopy
|
||||
the memo itself...
|
||||
"""
|
||||
try:
|
||||
memo[id(memo)].append(x)
|
||||
except KeyError:
|
||||
# aha, this is the first one :-)
|
||||
memo[id(memo)]=[x]
|
||||
|
||||
|
||||
# A cache for whichmodule(), mapping a function object to the name of
|
||||
# the module in which the function was found.
|
||||
|
||||
@ -832,7 +788,7 @@ class _Unpickler:
|
||||
read = self.read
|
||||
dispatch = self.dispatch
|
||||
try:
|
||||
while 1:
|
||||
while True:
|
||||
key = read(1)
|
||||
if not key:
|
||||
raise EOFError
|
||||
@ -862,7 +818,7 @@ class _Unpickler:
|
||||
dispatch = {}
|
||||
|
||||
def load_proto(self):
|
||||
proto = ord(self.read(1))
|
||||
proto = self.read(1)[0]
|
||||
if not 0 <= proto <= HIGHEST_PROTOCOL:
|
||||
raise ValueError("unsupported pickle protocol: %d" % proto)
|
||||
self.proto = proto
|
||||
@ -897,40 +853,37 @@ class _Unpickler:
|
||||
elif data == TRUE[1:]:
|
||||
val = True
|
||||
else:
|
||||
try:
|
||||
val = int(data, 0)
|
||||
except ValueError:
|
||||
val = int(data, 0)
|
||||
val = int(data, 0)
|
||||
self.append(val)
|
||||
dispatch[INT[0]] = load_int
|
||||
|
||||
def load_binint(self):
|
||||
self.append(mloads(b'i' + self.read(4)))
|
||||
self.append(unpack('<i', self.read(4))[0])
|
||||
dispatch[BININT[0]] = load_binint
|
||||
|
||||
def load_binint1(self):
|
||||
self.append(ord(self.read(1)))
|
||||
self.append(self.read(1)[0])
|
||||
dispatch[BININT1[0]] = load_binint1
|
||||
|
||||
def load_binint2(self):
|
||||
self.append(mloads(b'i' + self.read(2) + b'\000\000'))
|
||||
self.append(unpack('<H', self.read(2))[0])
|
||||
dispatch[BININT2[0]] = load_binint2
|
||||
|
||||
def load_long(self):
|
||||
val = self.readline()[:-1].decode("ascii")
|
||||
if val and val[-1] == 'L':
|
||||
val = self.readline()[:-1]
|
||||
if val and val[-1] == b'L'[0]:
|
||||
val = val[:-1]
|
||||
self.append(int(val, 0))
|
||||
dispatch[LONG[0]] = load_long
|
||||
|
||||
def load_long1(self):
|
||||
n = ord(self.read(1))
|
||||
n = self.read(1)[0]
|
||||
data = self.read(n)
|
||||
self.append(decode_long(data))
|
||||
dispatch[LONG1[0]] = load_long1
|
||||
|
||||
def load_long4(self):
|
||||
n = mloads(b'i' + self.read(4))
|
||||
n, = unpack('<i', self.read(4))
|
||||
if n < 0:
|
||||
# Corrupt or hostile pickle -- we never write one like this
|
||||
raise UnpicklingError("LONG pickle has negative byte count")
|
||||
@ -942,28 +895,25 @@ class _Unpickler:
|
||||
self.append(float(self.readline()[:-1]))
|
||||
dispatch[FLOAT[0]] = load_float
|
||||
|
||||
def load_binfloat(self, unpack=struct.unpack):
|
||||
def load_binfloat(self):
|
||||
self.append(unpack('>d', self.read(8))[0])
|
||||
dispatch[BINFLOAT[0]] = load_binfloat
|
||||
|
||||
def load_string(self):
|
||||
orig = self.readline()
|
||||
rep = orig[:-1]
|
||||
for q in (b'"', b"'"): # double or single quote
|
||||
if rep.startswith(q):
|
||||
if not rep.endswith(q):
|
||||
raise ValueError("insecure string pickle")
|
||||
rep = rep[len(q):-len(q)]
|
||||
break
|
||||
# Strip outermost quotes
|
||||
if rep[0] == rep[-1] and rep[0] in b'"\'':
|
||||
rep = rep[1:-1]
|
||||
else:
|
||||
raise ValueError("insecure string pickle: %r" % orig)
|
||||
raise ValueError("insecure string pickle")
|
||||
self.append(codecs.escape_decode(rep)[0]
|
||||
.decode(self.encoding, self.errors))
|
||||
dispatch[STRING[0]] = load_string
|
||||
|
||||
def load_binstring(self):
|
||||
# Deprecated BINSTRING uses signed 32-bit length
|
||||
len = mloads(b'i' + self.read(4))
|
||||
len, = unpack('<i', self.read(4))
|
||||
if len < 0:
|
||||
raise UnpicklingError("BINSTRING pickle has negative byte count")
|
||||
data = self.read(len)
|
||||
@ -971,7 +921,7 @@ class _Unpickler:
|
||||
self.append(value)
|
||||
dispatch[BINSTRING[0]] = load_binstring
|
||||
|
||||
def load_binbytes(self, unpack=struct.unpack, maxsize=sys.maxsize):
|
||||
def load_binbytes(self):
|
||||
len, = unpack('<I', self.read(4))
|
||||
if len > maxsize:
|
||||
raise UnpicklingError("BINBYTES exceeds system's maximum size "
|
||||
@ -983,7 +933,7 @@ class _Unpickler:
|
||||
self.append(str(self.readline()[:-1], 'raw-unicode-escape'))
|
||||
dispatch[UNICODE[0]] = load_unicode
|
||||
|
||||
def load_binunicode(self, unpack=struct.unpack, maxsize=sys.maxsize):
|
||||
def load_binunicode(self):
|
||||
len, = unpack('<I', self.read(4))
|
||||
if len > maxsize:
|
||||
raise UnpicklingError("BINUNICODE exceeds system's maximum size "
|
||||
@ -992,15 +942,15 @@ class _Unpickler:
|
||||
dispatch[BINUNICODE[0]] = load_binunicode
|
||||
|
||||
def load_short_binstring(self):
|
||||
len = ord(self.read(1))
|
||||
data = bytes(self.read(len))
|
||||
len = self.read(1)[0]
|
||||
data = self.read(len)
|
||||
value = str(data, self.encoding, self.errors)
|
||||
self.append(value)
|
||||
dispatch[SHORT_BINSTRING[0]] = load_short_binstring
|
||||
|
||||
def load_short_binbytes(self):
|
||||
len = ord(self.read(1))
|
||||
self.append(bytes(self.read(len)))
|
||||
len = self.read(1)[0]
|
||||
self.append(self.read(len))
|
||||
dispatch[SHORT_BINBYTES[0]] = load_short_binbytes
|
||||
|
||||
def load_tuple(self):
|
||||
@ -1039,12 +989,9 @@ class _Unpickler:
|
||||
|
||||
def load_dict(self):
|
||||
k = self.marker()
|
||||
d = {}
|
||||
items = self.stack[k+1:]
|
||||
for i in range(0, len(items), 2):
|
||||
key = items[i]
|
||||
value = items[i+1]
|
||||
d[key] = value
|
||||
d = {items[i]: items[i+1]
|
||||
for i in range(0, len(items), 2)}
|
||||
self.stack[k:] = [d]
|
||||
dispatch[DICT[0]] = load_dict
|
||||
|
||||
@ -1096,17 +1043,17 @@ class _Unpickler:
|
||||
dispatch[GLOBAL[0]] = load_global
|
||||
|
||||
def load_ext1(self):
|
||||
code = ord(self.read(1))
|
||||
code = self.read(1)[0]
|
||||
self.get_extension(code)
|
||||
dispatch[EXT1[0]] = load_ext1
|
||||
|
||||
def load_ext2(self):
|
||||
code = mloads(b'i' + self.read(2) + b'\000\000')
|
||||
code, = unpack('<H', self.read(2))
|
||||
self.get_extension(code)
|
||||
dispatch[EXT2[0]] = load_ext2
|
||||
|
||||
def load_ext4(self):
|
||||
code = mloads(b'i' + self.read(4))
|
||||
code, = unpack('<i', self.read(4))
|
||||
self.get_extension(code)
|
||||
dispatch[EXT4[0]] = load_ext4
|
||||
|
||||
@ -1174,7 +1121,7 @@ class _Unpickler:
|
||||
self.append(self.memo[i])
|
||||
dispatch[BINGET[0]] = load_binget
|
||||
|
||||
def load_long_binget(self, unpack=struct.unpack):
|
||||
def load_long_binget(self):
|
||||
i, = unpack('<I', self.read(4))
|
||||
self.append(self.memo[i])
|
||||
dispatch[LONG_BINGET[0]] = load_long_binget
|
||||
@ -1193,7 +1140,7 @@ class _Unpickler:
|
||||
self.memo[i] = self.stack[-1]
|
||||
dispatch[BINPUT[0]] = load_binput
|
||||
|
||||
def load_long_binput(self, unpack=struct.unpack, maxsize=sys.maxsize):
|
||||
def load_long_binput(self):
|
||||
i, = unpack('<I', self.read(4))
|
||||
if i > maxsize:
|
||||
raise ValueError("negative LONG_BINPUT argument")
|
||||
@ -1238,7 +1185,7 @@ class _Unpickler:
|
||||
state = stack.pop()
|
||||
inst = stack[-1]
|
||||
setstate = getattr(inst, "__setstate__", None)
|
||||
if setstate:
|
||||
if setstate is not None:
|
||||
setstate(state)
|
||||
return
|
||||
slotstate = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user