Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions near/datasets/refine_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,29 @@

class AlanDataset(Dataset):

def __init__(self, root='../../data/Alan', appearance_path='appearance', shape_path='shape', resolution=128, n_samples=None):
def __init__(self, root='../../data/Alan', appearance_path='appearance', shape_path='shape', resolution=128, n_samples=None, filter_low_quality=True):

self.root = root
self.resolution = resolution
self.appearance_dir = os.path.join(root, appearance_path)
self.shape_dir = os.path.join(root, shape_path)

df = pd.read_csv(os.path.join(self.root, 'info.csv'))
info = df[df['low_quality'].isnull()]

# Handle optional 'low_quality' column
if filter_low_quality and 'low_quality' in df.columns:
info = df[df['low_quality'].isnull()]
else:
if filter_low_quality and 'low_quality' not in df.columns:
print("Warning: 'low_quality' column not found in info.csv. Using all samples.")
info = df

# Check for required columns
required_cols = ['ROI_id', 'ROI_anomaly']
missing_cols = [col for col in required_cols if col not in info.columns]
if missing_cols:
raise KeyError(f"Required columns {missing_cols} not found in info.csv. Available columns: {list(df.columns)}")

self.info = info[['ROI_id', 'ROI_anomaly']]
self.info.reset_index(drop=True, inplace=True)

Expand Down