本文最后更新于:2023年8月9日 下午
来自于 b 站《PyTorch 深度学习快速入门教程(绝对通俗易懂!)【小土堆】》
本篇内容主要是对代码中的一些小知识点进行归纳总结,以便复习
下面是主要代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
| from torch.utils.data import Dataset from PIL import Image import os
class MyData(Dataset): def __init__(self,root_dir,label_dir):
self.root_dir=root_dir self.label_dir=label_dir self.path=os.path.join(self.root_dir,self.label_dir) self.img_path=os.listdir(self.path)
def __getitem__(self, idx):
img_name=self.img_path[idx] img_item_path=os.path.join(self.root_dir,self.label_dir,img_name) img=Image.open(img_item_path) label=self.label_dir return img, label
def __len__(self): return len(self.img_path)
root_dir="dataset/train" ants_label_dir="ants_image" ants_dataset=MyData(root_dir,ants_label_dir) target_dir="ants_image" label=target_dir.split('_')[0] img_path=os.listdir(os.path.join(root_dir,target_dir)) out_dir="ants_label" for i in img_path: file_name=i.split('.jpg')[0] with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w') as f: f.write(label)
|
另外,配环境配了好久。。。