|
| 1 | +import os |
| 2 | +from typing import Optional, Union |
| 3 | + |
| 4 | +from DPF.datatypes import ColumnDataType, DataType, FileDataType |
| 5 | +from DPF.modalities import MODALITIES, ModalityName |
| 6 | + |
| 7 | +from .dataset_config import DatasetConfig |
| 8 | + |
| 9 | + |
class FilesDatasetConfig(DatasetConfig):
    """Config for Files dataset type"""

    def __init__(
        self,
        path: str,
        datatypes: list[Union[FileDataType, ColumnDataType]],
    ):
        """
        Parameters
        ----------
        path: str
            Path to dataset metadata file
        datatypes: list[Union[FileDataType, ColumnDataType]]
            List of datatypes in dataset (at most one per modality)

        Raises
        ------
        AssertionError
            If an element of `datatypes` is neither a FileDataType nor a
            ColumnDataType, or if two datatypes share the same modality
        """
        super().__init__(path)
        self.table_path = path
        # Directory that contains the metadata file.
        self.base_path = os.path.dirname(self.table_path)

        # Check element types BEFORE reading `.modality`, so that a wrong
        # element is reported by this assertion and not by an unrelated
        # AttributeError raised while building the mapping below.
        for data in datatypes:
            assert isinstance(data, (ColumnDataType, FileDataType)), \
                f"Unsupported datatype: {data!r}"

        self._datatypes = datatypes
        self._modality2datatype = {d.modality.name: d for d in datatypes}

        # The mapping is keyed by modality name, so duplicated modalities
        # collapse into one entry and make the mapping shorter than the list.
        assert len(self._modality2datatype) == len(datatypes), \
            "More than one datatype with same modality is not supported"

    @property
    def datatypes(self) -> list[DataType]:
        """Datatypes passed to the constructor, in the same order."""
        return self._datatypes  # type: ignore

    @property
    def modality2datatype(self) -> dict[ModalityName, DataType]:
        """Mapping from modality name to the datatype of that modality."""
        return self._modality2datatype  # type: ignore

    @property
    def user_column2default_column(self) -> dict[str, str]:
        """Mapping from user-provided column names to default column names.

        For a ColumnDataType the target is its `column_name`; for a
        FileDataType the target is the `path_column` of its modality.
        """
        mapping: dict[str, str] = {}
        for data in self.datatypes:
            if isinstance(data, ColumnDataType):
                mapping[data.user_column_name] = data.column_name
            elif isinstance(data, FileDataType):
                mapping[data.user_path_column_name] = data.modality.path_column
        return mapping

    @classmethod
    def from_path_and_columns(
        cls,
        path: str,
        image_path_col: Optional[str] = None,
        video_path_col: Optional[str] = None,
        text_col: Optional[str] = None,
    ) -> "FilesDatasetConfig":
        """
        Parameters
        ----------
        path: str
            Path to dataset metadata file
        image_path_col: Optional[str] = None
            Name of column with image paths
        video_path_col: Optional[str] = None
            Name of column with video paths
        text_col: Optional[str] = None
            Name of column with text

        Returns
        -------
        FilesDatasetConfig
            Instance of itself

        Raises
        ------
        AssertionError
            If none of the column names is provided
        """
        datatypes: list[Union[FileDataType, ColumnDataType]] = []
        # Falsy column names (None or empty string) are treated as "not given".
        if image_path_col:
            datatypes.append(FileDataType(MODALITIES['image'], image_path_col))
        if video_path_col:
            datatypes.append(FileDataType(MODALITIES['video'], video_path_col))
        if text_col:
            datatypes.append(ColumnDataType(MODALITIES['text'], text_col))
        assert len(datatypes) > 0, "At least one modality should be provided"
        return cls(path, datatypes)

    def __repr__(self) -> str:
        """Multi-line, tab-indented description of the table path and datatypes."""
        datatypes_repr = '\n\t\t'.join(str(d) for d in self.datatypes)
        return (
            "FilesDatasetConfig(\n\t"
            f'table_path="{self.table_path}",\n\t'
            'datatypes=[\n\t\t'
            f'{datatypes_repr}'
            '\n\t]'
            '\n)'
        )
0 commit comments