#!/usr/bin/env python3
# /h/ was here
"""Inspect a PyTorch checkpoint pickle without executing arbitrary globals.

The unpickler below resolves only a small allowlist of globals; every other
global lookup raises ``pickle.UnpicklingError``.
"""

import builtins
import collections
import io
import os
import pickle

import numpy
import torch

import _codecs


def encode(*args):
    """Call ``_codecs.encode(*args)``, print the call and result, return the result."""
    out = _codecs.encode(*args)
    print(f'encode({args}) = {out}')
    return out


class RestrictedUnpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that only resolves an explicit allowlist of globals."""

    def persistent_load(self, saved_id):
        """Return an empty placeholder storage for a ``'storage'`` persistent id.

        Raises:
            pickle.UnpicklingError: if ``saved_id`` is not a sequence whose
                first element is ``'storage'``.
        """
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would silently disable this validation.
        try:
            tag = saved_id[0]
        except (TypeError, IndexError, KeyError):
            tag = None
        if tag != 'storage':
            raise pickle.UnpicklingError(
                "unsupported persistent id: %r" % (saved_id,)
            )
        # NOTE(review): ``_TypedStorage`` is a private torch name; whether it
        # exists depends on the installed torch version -- confirm.
        return torch.storage._TypedStorage()

    def find_class(self, module, name):
        """Resolve ``module``/``name`` to an allowlisted object.

        Raises:
            pickle.UnpicklingError: if the global is not on the allowlist.
        """
        print(f'find class {module}{name}')
        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name == '_rebuild_tensor_v2':
            return torch._utils._rebuild_tensor_v2
        if module == 'torch' and name in ('FloatStorage', 'HalfStorage'):
            # NOTE(review): HalfStorage is mapped to FloatStorage as well;
            # presumably fine because storages are replaced by placeholders
            # in persistent_load -- confirm.
            return torch.FloatStorage
        if module == 'numpy.core.multiarray' and name == 'scalar':
            return numpy.core.multiarray.scalar
        if module == 'numpy' and name == 'dtype':
            return numpy.dtype
        if module == '_codecs' and name == 'encode':
            # Use the logging wrapper defined above rather than the raw codec.
            return encode
        # Forbid everything else.
        raise pickle.UnpicklingError(
            "global '%s/%s' is forbidden" % (module, name)
        )


def restricted_loads(s):
    """Helper function analogous to pickle.loads()."""
    return RestrictedUnpickler(io.BytesIO(s)).load()


# To test that it catches this RCE:
# restricted_loads(b"cos\nsystem\n(S'echo hello world'\ntR.")


def main():
    """Load ``archive/data.pkl`` with the restricted unpickler and print a summary."""
    # unzip model.ckpt archive/data.pkl
    with open('archive/data.pkl', 'rb') as f:
        st = f.read()
    d = restricted_loads(st)
    print(dir(d))
    print(d.keys())
    print(d['callbacks'])


# Guarded so importing this module does not require the archive to exist.
if __name__ == '__main__':
    main()