File tree Expand file tree Collapse file tree
PyTorch/Classification/ConvNets/image_classification Expand file tree Collapse file tree Original file line number Diff line number Diff line change 3434import torchvision .transforms as transforms
3535from PIL import Image
3636from functools import partial
37+ from torchvision .transforms .functional import InterpolationMode
3738
3839from image_classification .autoaugment import AutoaugmentImageNetPolicy
3940
@@ -422,9 +423,10 @@ def get_pytorch_train_loader(
422423 prefetch_factor = 2 ,
423424 memory_format = torch .contiguous_format ,
424425):
425- interpolation = {"bicubic" : Image .BICUBIC , "bilinear" : Image .BILINEAR }[
426- interpolation
427- ]
426+ interpolation = {
427+ "bicubic" : InterpolationMode .BICUBIC ,
428+ "bilinear" : InterpolationMode .BILINEAR ,
429+ }[interpolation ]
428430 traindir = os .path .join (data_path , "train" )
429431 transforms_list = [
430432 transforms .RandomResizedCrop (image_size , interpolation = interpolation ),
@@ -474,9 +476,10 @@ def get_pytorch_val_loader(
474476 memory_format = torch .contiguous_format ,
475477 prefetch_factor = 2 ,
476478):
477- interpolation = {"bicubic" : Image .BICUBIC , "bilinear" : Image .BILINEAR }[
478- interpolation
479- ]
479+ interpolation = {
480+ "bicubic" : InterpolationMode .BICUBIC ,
481+ "bilinear" : InterpolationMode .BILINEAR ,
482+ }[interpolation ]
480483 valdir = os .path .join (data_path , "val" )
481484 val_dataset = datasets .ImageFolder (
482485 valdir ,
You can’t perform that action at this time.
0 commit comments