可持久化对象不如直接多线程dataloader。
想把处理好的对象直接保存下来,然后写了一个Cache类
第一次会把对象保存到硬盘, 第二次会直接读取
import os
from abc import ABCMeta, abstractmethod

import torch


class DataBase(metaclass=ABCMeta):
    """Abstract provider of an expensive-to-build object that Cache can persist.

    BUG FIX: the original used the Python 2 idiom ``__metaclass__ = ABCMeta``,
    which is silently ignored in Python 3, so ``@abstractmethod`` was never
    enforced.  Declaring ``metaclass=ABCMeta`` restores the enforcement
    (instantiating a subclass without ``get_data`` now raises TypeError).
    """

    @abstractmethod
    def get_data(self):
        """Build and return the object to cache (called only on a cache miss)."""
        raise NotImplementedError


class Cache:
    """Disk-backed object cache: the first fetch computes and saves the object,
    later fetches load it straight from disk via ``torch.load``."""

    def __init__(self, path):
        # Root directory for cached objects; created lazily on the first save.
        self.path = path

    def fetch(self, name, database):
        """Return the object stored under ``name``, building it on a miss.

        Parameters
        ----------
        name : str
            File name of the cached object inside ``self.path``.
        database : DataBase
            Provider whose ``get_data()`` builds the object when not cached.

        Returns
        -------
        The cached (or freshly built) object.
        """
        save_path = os.path.join(self.path, name)
        if os.path.exists(save_path):
            # Cache hit.  NOTE: torch.load unpickles arbitrary objects — only
            # load files this cache wrote itself; never point it at untrusted data.
            return torch.load(save_path)
        # Cache miss: build, persist, then return.
        data = database.get_data()
        # Robustness fix: the original crashed if the cache directory
        # did not exist yet.
        os.makedirs(self.path, exist_ok=True)
        torch.save(data, save_path)
        return data
# NOTE(review): this block assumes ``train_dataloader``, ``test_dataloader``,
# ``tqdm`` and ``time`` (as in ``from time import time``) are defined earlier
# in the file — confirm against the full script.

class TrainDataBase(DataBase):
    """Materializes the training dataloader into a list of
    ``{'image': ..., 'pose': ...}`` dicts so it can be cached to disk."""

    def get_data(self):
        # Idiom fix: the original iterated ``tqdm(enumerate(loader))`` with an
        # unused index ``i``; wrapping enumerate also hid the progress-bar total.
        return [{'image': data['image'], 'pose': data['pose']}
                for data in tqdm(train_dataloader)]


class TestDataBase(DataBase):
    """Materializes the test dataloader into a list of
    ``{'image': ..., 'pose': ...}`` dicts so it can be cached to disk."""

    def get_data(self):
        return [{'image': data['image'], 'pose': data['pose']}
                for data in tqdm(test_dataloader)]


persistent_path = "/SSD1/save_obj"
cache = Cache(persistent_path)

time1 = time()
train_data_vec = cache.fetch("mobile-net-train-data", TrainDataBase())
time2 = time()
test_data_vec = cache.fetch("mobile-net-test-data", TestDataBase())
time3 = time()
# BUG FIX: the original was ``print("%s, %s\n", a, b)`` — the format string was
# passed as a plain first argument, so "%s, %s" printed literally followed by
# the two numbers.  Apply the % operator instead.
# NOTE(review): the second figure is cumulative (time3 - time1); if per-fetch
# timing was intended it should be ``time3 - time2`` — confirm with the author.
print("%s, %s\n" % (time2 - time1, time3 - time1))
然后测试结果是 74.46734094619751s 和 75.0271668434143s,而原来的多线程方案不到 1 分钟……去看了一下生成的文件,竟有 12G。因为 DataLoader 本身就是多线程处理的,所以直接用它反而快很多。