diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9eded80..0e63200 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,20 +4,6 @@ on: push jobs: - test_installer: - runs-on: ubuntu-18.04 - steps: - - uses: actions/checkout@v2 - - uses: webfactory/ssh-agent@v0.4.1 - with: - ssh-private-key: ${{ secrets.VEDC_PRIVATE_KEY }} - - name: Set up Python 3.6 - uses: actions/setup-python@v1 - with: - python-version: 3.6 - - name: Run installer - run: python installer/install_ved_capture.py -y -v -b ${GITHUB_REF##*/} --no_ssh --no_version_check - test_app: runs-on: ubuntu-18.04 steps: @@ -30,24 +16,14 @@ jobs: run: | conda --version which python - - name: Extract branch name - shell: bash - run: echo "##[set-output name=branch;]$(echo ${GITHUB_REF##*/})" - id: extract_branch - name: Create environment - env: - VEDC_DEV: true - VEDC_PIN: ${{ steps.extract_branch.outputs.branch }} - run: | - conda install -y -c conda-forge conda-devenv - conda devenv - - name: Copy paths.json run: | - mkdir -p /home/runner/.config/vedc - cp tests/test_data/config/paths.json /home/runner/.config/vedc + conda install -y -c conda-forge mamba + mamba env create - name: Run pytest run: | source activate vedc + mamba install -c conda-forge pytest pytest lint: diff --git a/.gitignore b/.gitignore index 41ebd1f..bf3315d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -environment.yml **/out/ tests/test_data/user_config scratch diff --git a/README.md b/README.md index 8bfc9d0..7ca6fc0 100644 --- a/README.md +++ b/README.md @@ -4,27 +4,33 @@ # ved-capture -**ved-capture** is the app for simultaneous recording of video, gaze and head -tracking data for the Visual Experience Database. +**ved-capture** is the app for simultaneous recording of video, gaze and head tracking data for the Visual Experience Database. ## Installation -The app can be installed on most Linux systems with a single Python script that -can be downloaded [here](https://github.com/vedb/ved-capture/blob/master/installer/install_ved_capture.py) -by right clicking on the "Raw" button and then "Save target as" -or on the [Releases page](https://github.com/vedb/ved-capture/releases). +The following instructions assume that you have installed an Anaconda/Miniconda Python distribution and have set up and SSH keypair with GitHub. If you haven't, please refer to the [installation instructions in the wiki](https://github.com/vedb/ved-capture/wiki/Installation). - $ python3 install_ved_capture.py +Clone the repository: + + $ git clone ssh://git@github.com:vedb/ved-capture + $ cd ved-capture + +Set up the environment: + + $ conda env create + +Configure system: + + $ bash configure.sh + +Reload `.bashrc` (or the appropriate rc file for your shell) and check installation: -The script will guide you through the setup process and instruct you what to -do. Since the app isolates all of its dependencies in a dedicated -environment, the installation has a size of about 3.5 GB, so make sure you -have enough space. + $ source ~/.bashrc + $ vedc check_install ## Usage -The central tool of this app is the command line tool `vedc`. You can use it -to generate recording configurations, make recordings, update the app and more. +The central tool of this app is the command line tool `vedc`. You can use it to generate recording configurations, make recordings and more. ### Generating a configuration @@ -32,8 +38,7 @@ Plug in your hardware (Pupil core system, RealSense T265, FLIR camera) and run: $ vedc auto_config -This will check your connected devices and auto-generate a configuration for -your current setup. +This will check your connected devices and auto-generate a configuration for your current setup. ### Streaming video @@ -43,5 +48,4 @@ To show all camera streams, run: ### Other commands -Check out the [wiki](https://github.com/vedb/ved-capture/wiki) for a -comprehensive list of available commands. +Check out the [wiki](https://github.com/vedb/ved-capture/wiki) for a comprehensive list of available commands. diff --git a/configure.sh b/configure.sh new file mode 100644 index 0000000..9330e8f --- /dev/null +++ b/configure.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# configure Pupil udev rules +sudo usermod -a -G plugdev $USER +echo 'SUBSYSTEM=="usb", ENV{DEVTYPE}=="usb_device", GROUP="plugdev", MODE="0664"' | sudo tee /etc/udev/rules.d/10-libuvc.rules > /dev/null +sudo udevadm trigger + +# configure Spinnaker udev rules +sudo groupadd -f flirimaging +sudo usermod -a -G flirimaging $USER +echo 'SUBSYSTEM=="usb", ATTRS{idVendor}=="1e10", GROUP="flirimaging"' | sudo tee /etc/udev/rules.d/40-flir-spinnaker.rules > /dev/null +sudo /etc/init.d/udev restart + +# increase USBFS memory +sudo sed -i s/GRUB_CMDLINE_LINUX_DEFAULT="quiet splash"/GRUB_CMDLINE_LINUX_DEFAULT="quiet splash usbcore.usbfs_memory_mb=1000"/ /etc/default/grub +sudo update-grub + +# remove old vedc executable +sudo rm -f /usr/local/bin/vedc + +# create alias +CLI_CMD="vedc-cli () { if [[ \$CONDA_DEFAULT_ENV = vedc ]]; then vedc \$@; else conda activate vedc; vedc \$@; conda deactivate; fi }" +ALIAS_CMD="alias vedc=vedc-cli" + +if [[ "$SHELL" == *bash ]] ; then + RC_FILE="$HOME/.bashrc" +elif [[ "$SHELL" == *zsh ]] ; then + RC_FILE="$HOME/.zshrc" +else + echo "Could not determine shell type, add the following commands to your shell's rc file manually:" + echo "${CLI_CMD}" + echo "${ALIAS_CMD}" + exit 1 +fi + +if ! grep -Fxq "${CLI_CMD}" "${RC_FILE}"; then + echo "${CLI_CMD}" >> "${RC_FILE}" +fi + +if ! grep -Fxq "${ALIAS_CMD}" "${RC_FILE}"; then + echo "${ALIAS_CMD}" >> "${RC_FILE}" +fi diff --git a/environment.devenv.yml b/environment.yml similarity index 51% rename from environment.devenv.yml rename to environment.yml index ba1960f..fe11a23 100644 --- a/environment.devenv.yml +++ b/environment.yml @@ -1,22 +1,8 @@ -{{ min_conda_devenv_version("2.1.1") }} - -# set library version -{% set pri_version = "0.4.1" %} - -# version and config dir can be overridden by env vars -{% set pri_pin = os.environ.get("PRI_PIN", pri_version) %} -{% if "VEDCDIR" in os.environ %} -environment: - VEDCDIR: {{ os.environ["VEDCDIR"] }} -{% endif %} - name: vedc - channels: - conda-forge - loopbio - vedb - dependencies: - python=3.6 - pip=20.1 @@ -29,7 +15,6 @@ dependencies: # video - opencv=4.2.0 - ffmpeg=3.4.2=x265_0 - - av=6.2.0 - x264=1!152.20180806 - x265=2.8 # pupil @@ -38,7 +23,8 @@ dependencies: - libjpeg-turbo=2.0.5 - pyuvc=0.14 - msgpack-python=0.6.2 - - pupil-detectors=1.0.5 + - pupil-detectors=1.1.1 + - pupil_recording_interface=0.5.0 # realsense - librealsense=2.42.0 - pyrealsense2=2.42.0.2849 @@ -52,24 +38,10 @@ dependencies: - click=7.1.1 # other - matplotlib=3.1.1 - - gitpython=3.1.0 - tqdm=4.46.0 - simpleaudio=1.0.2 - blessed=1.17.12 - multiprocessing-logging=0.3.1 - # install local editable version of PRI if "PRI_PATH" is set - {% if "PRI_PATH" in os.environ %} - - pip: - - "--editable {{ os.environ['PRI_PATH'] }}" - {% else %} - - pupil_recording_interface={{ pri_pin }} - {% endif %} # install local editable version of ved-capture - pip: - - "--editable {{ root }}" - # development - {% if "VEDC_DEV" in os.environ %} - - pytest - - bump2version - - setproctitle - {% endif %} + - -e . diff --git a/installer/install_ved_capture.py b/installer/install_ved_capture.py index 72eae74..f61d01b 100644 --- a/installer/install_ved_capture.py +++ b/installer/install_ved_capture.py @@ -1,5 +1,7 @@ """ Installation script for ved-capture. +WARNING: THIS INSTALLER IS DEPRECATED AND WILL LIKELY NOT WORK! + This script will set up ved-capture including all of its dependencies as well as the `vedc` command line interface. @@ -28,7 +30,7 @@ import re -__installer_version = "1.4.4" +__installer_version = "1.4.5" # -- LOGGING -- # logger = logging.getLogger(Path(__file__).stem) @@ -70,6 +72,8 @@ def show_welcome_message(yes=False): "# Welcome to the VED capture installation script. #\n" "###################################################\n" "\n" + "WARNING: THIS INSTALLER IS DEPRECATED AND WILL LIKELY NOT WORK!\n" + "\n" "This script will guide you through the setup process for VED " "capture.\n" "\n" @@ -91,6 +95,8 @@ def show_welcome_message(yes=False): "###################################################\n" "# VED capture installation - auto-mode. #\n" "###################################################" + "\n" + "WARNING: THIS INSTALLER IS DEPRECATED AND WILL LIKELY NOT WORK!\n" ) return True diff --git a/installer/test_installer.py b/installer/test_installer.py deleted file mode 100644 index 5b2fb95..0000000 --- a/installer/test_installer.py +++ /dev/null @@ -1,109 +0,0 @@ -import sys -import shutil -from pathlib import Path - -import pytest - -from install_ved_capture import ( - run_command, - check_ssh_pubkey, - get_repo_folder, - get_version_or_branch, - clone_repo, - get_min_conda_devenv_version, - install_miniconda, -) - - -@pytest.fixture(autouse=True) -def setup(): - """""" - # maybe not the best solution - sys.path.append(str(Path.cwd())) - - -@pytest.fixture() -def output_folder(): - """""" - folder = Path.cwd() / "out" - yield folder - shutil.rmtree(folder, ignore_errors=True) - - -@pytest.fixture() -def repo_url(): - """""" - return "ssh://git@github.com/vedb/ved-capture" - - -@pytest.fixture() -def local_repo_folder(): - """""" - return Path(__file__).parents[1] - - -@pytest.fixture() -def repo_folder(repo_url, output_folder): - """""" - folder = output_folder / "ved-capture" - run_command(["git", "clone", "--depth", "1", repo_url, folder]) - run_command( - [ - "git", - f"--work-tree={folder}", - f"--git-dir={folder}/.git", - "fetch", - "--tags", - ] - ) - - return folder - - -class TestMethods: - @pytest.mark.xfail(reason="Fails on GitHub actions") - def test_check_ssh_pubkey(self): - """""" - assert check_ssh_pubkey() is not None - assert check_ssh_pubkey("not_a_key") is None - - def test_get_repo_folder(self, output_folder): - """""" - assert ( - get_repo_folder( - output_folder, "ssh://git@github.com/vedb/ved-capture", - ).stem - == "ved-capture" - ) - - def test_get_version_or_branch(self, repo_folder): - """""" - import re - - pattern = re.compile(r"^v[0-9]+\.[0-9]+\.[0-9]+$") - assert re.match(pattern, get_version_or_branch(repo_folder)) - assert get_version_or_branch(repo_folder, "devel") == "devel" - - def test_clone_repo(self, output_folder, repo_url): - """""" - repo_folder = output_folder / "ved-capture" - clone_repo(output_folder, repo_folder, repo_url) - - assert (output_folder / "ved-capture" / ".git").exists() - - with pytest.raises(SystemExit): - clone_repo( - output_folder, - repo_folder, - "ssh://git@github.com/vedb/wrong_repo", - ) - - def test_install_miniconda(self, output_folder): - """""" - install_miniconda(prefix=output_folder) - assert (output_folder / "bin" / "conda").exists() - - def test_get_min_conda_devenv_version(self, local_repo_folder): - """""" - version = get_min_conda_devenv_version(local_repo_folder) - assert version == "2.1.1" diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index feb2337..0000000 --- a/tests/test_cli.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest - -from click.testing import CliRunner - -from ved_capture.cli import record, update - - -class TestCli: - @pytest.mark.skip("skip until we figure out how to run this during CI") - def test_record(self, config_dir): - """""" - runner = CliRunner() - result = runner.invoke( - record, f"-v -c {config_dir}/config_minimal.yaml" - ) - - assert result.exit_code == 0 - - @pytest.mark.skip("skip until we figure out how to run this during CI") - def test_update_cli(self): - """""" - runner = CliRunner() - result = runner.invoke(update, "-l -v") - - assert result.exit_code == 0 diff --git a/tests/test_config.py b/tests/test_config.py index ace772a..e6834b1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -202,9 +202,11 @@ def test_get_validation_configs(self, parser): config_list = parser.get_validation_configs() assert config_list[0].stream_type == "video" - assert config_list[0].pipeline[0].process_type == "circle_detector" + assert ( + config_list[0].pipeline[0].process_type == "circle_detector_vedb" + ) assert config_list[0].pipeline[0].min_area == 200 - assert config_list[0].pipeline[1].process_type == "validation" + # TODO assert config_list[0].pipeline[1].process_type == "validation" assert config_list[0].pipeline[2].process_type == "gaze_mapper" assert config_list[0].pipeline[3].process_type == "video_display" diff --git a/tests/test_utils.py b/tests/test_utils.py index ec15e72..527c9b6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,4 @@ -import pytest - -from git.exc import GitError, NoSuchPathError - from ved_capture.utils import ( - get_paths, - update_repo, get_pupil_devices, get_realsense_devices, get_flir_devices, @@ -13,34 +7,6 @@ class TestUtils: - def test_get_paths(self, config_dir): - """""" - paths = get_paths(config_dir) - assert paths == { - "conda_binary": "/usr/share/miniconda/bin/conda", - "conda_script": "/usr/share/miniconda/etc/profile.d/conda.sh", - "vedc_repo_folder": "/home/runner/work/ved-capture/ved-capture", - } - - @pytest.mark.xfail(NoSuchPathError) - def test_update_repo(self, user_config_dir): - """""" - # update once - update_repo(get_paths(user_config_dir)["vedc_repo_folder"]) - - # update again and assert no changes - assert not update_repo(get_paths(user_config_dir)["vedc_repo_folder"]) - - # checkout branch - update_repo(get_paths(user_config_dir)["vedc_repo_folder"], "devel") - assert not update_repo( - get_paths(user_config_dir)["vedc_repo_folder"], "devel" - ) - - # wrong folder - with pytest.raises(GitError): - update_repo("not_a_folder") - def test_get_pupil_devices(self): """""" # TODO check return value diff --git a/ved_capture/cli/__init__.py b/ved_capture/cli/__init__.py index 11e584d..115b4ce 100644 --- a/ved_capture/cli/__init__.py +++ b/ved_capture/cli/__init__.py @@ -35,7 +35,6 @@ def vedc(): vedc.add_command(generate_config) vedc.add_command(auto_config) vedc.add_command(edit_config) -vedc.add_command(update) vedc.add_command(check_install) vedc.add_command(save_logs) vedc.add_command(export) diff --git a/ved_capture/cli/app.py b/ved_capture/cli/app.py index 800c1e1..62acec6 100644 --- a/ved_capture/cli/app.py +++ b/ved_capture/cli/app.py @@ -1,18 +1,68 @@ import importlib import inspect -import sys import traceback import tarfile from pathlib import Path import click -from git import GitError from ved_capture.cli.utils import init_logger, raise_error -from ved_capture.utils import get_paths, update_repo, update_environment from ved_capture.config import ConfigParser +@click.command("check_install") +@click.option( + "-v", "--verbose", default=False, help="Verbose output.", count=True, +) +def check_install(verbose): + """ Test installation. """ + logger = init_logger(inspect.stack()[0][3], verbosity=verbose) + + failures = [] + + def check_import(module): + try: + importlib.import_module(module) + except ImportError: + logger.error(f"Could not import {module}.") + logger.debug(traceback.format_exc()) + failures.append(module) + + for module in ["uvc", "pupil_detectors", "PySpin", "pyrealsense2"]: + check_import(module) + + if len(failures) == 0: + logger.info("Installation check OK.") + else: + raise_error("Installation check failed!", logger) + + +@click.command("save_logs") +@click.option( + "-f", "--filepath", default="~/vedc_logs.tar.gz", help="Output file path.", +) +@click.option( + "-o", + "--overwrite", + default=False, + help="Overwrite existing file.", + is_flag=True, +) +def save_logs(filepath, overwrite): + """ Save logs to a gzipped tar archive. """ + filepath = Path(filepath).expanduser() + source_dir = Path(ConfigParser.config_dir()) + + if filepath.exists() and not overwrite: + raise click.ClickException( + f"{filepath} exists, set -o/--overwrite flag to overwrite" + ) + + with tarfile.open(filepath, "w:gz") as tar: + tar.add(source_dir, arcname=source_dir.stem) + + +### DEPRECATED ### @click.command("update") @click.option( "-v", "--verbose", default=False, help="Verbose output.", count=True, @@ -47,6 +97,10 @@ ) def update(verbose, local, branch, stash, pri_branch, pri_path, force): """ Update installation. """ + import sys + from git import GitError + from ved_capture.utils import get_paths, update_repo, update_environment + logger = init_logger(inspect.stack()[0][3], verbosity=verbose) if pri_branch and pri_path: @@ -95,55 +149,3 @@ def update(verbose, local, branch, stash, pri_branch, pri_path, force): symlink.symlink_to( ConfigParser.config_dir(), target_is_directory=True ) - - -@click.command("check_install") -@click.option( - "-v", "--verbose", default=False, help="Verbose output.", count=True, -) -def check_install(verbose): - """ Test installation. """ - logger = init_logger(inspect.stack()[0][3], verbosity=verbose) - - failures = [] - - def check_import(module): - try: - importlib.import_module(module) - except ImportError: - logger.error(f"Could not import {module}.") - logger.debug(traceback.format_exc()) - failures.append(module) - - for module in ["uvc", "pupil_detectors", "PySpin", "pyrealsense2"]: - check_import(module) - - if len(failures) == 0: - logger.info("Installation check OK.") - else: - raise_error("Installation check failed!", logger) - - -@click.command("save_logs") -@click.option( - "-f", "--filepath", default="~/vedc_logs.tar.gz", help="Output file path.", -) -@click.option( - "-o", - "--overwrite", - default=False, - help="Overwrite existing file.", - is_flag=True, -) -def save_logs(filepath, overwrite): - """ Save logs to a gzipped tar archive. """ - filepath = Path(filepath).expanduser() - source_dir = Path(ConfigParser.config_dir()) - - if filepath.exists() and not overwrite: - raise click.ClickException( - f"{filepath} exists, set -o/--overwrite flag to overwrite" - ) - - with tarfile.open(filepath, "w:gz") as tar: - tar.add(source_dir, arcname=source_dir.stem) diff --git a/ved_capture/cli/config.py b/ved_capture/cli/config.py index 4e31bc5..c67095d 100644 --- a/ved_capture/cli/config.py +++ b/ved_capture/cli/config.py @@ -180,9 +180,12 @@ def auto_config(verbose, test_folder, no_metadata): nested_dict = lambda: defaultdict(nested_dict) # noqa config = nested_dict() - # get version from default config + # load default config with open(Path(__file__).parents[1] / "config_default.yaml") as f: - config["version"] = yaml.safe_load(f)["version"] + default_config = yaml.safe_load(f) + + # set version + config["version"] = default_config["version"] # set test folder if specified if test_folder is not None: @@ -193,6 +196,9 @@ def auto_config(verbose, test_folder, no_metadata): if no_metadata: config["commands"]["record"]["metadata"] = None else: + config["commands"]["record"]["metadata"] = default_config["commands"][ + "record" + ]["metadata"] config["commands"]["record"]["metadata"]["study_site"] = input( "Please enter the study site (UNR, NDSU, Bates, ...): " ) diff --git a/ved_capture/cli/export.py b/ved_capture/cli/export.py index e44590e..b73586a 100644 --- a/ved_capture/cli/export.py +++ b/ved_capture/cli/export.py @@ -1,23 +1,18 @@ -from pathlib import Path import inspect +import json +from pathlib import Path from pprint import pformat import click import pupil_recording_interface as pri -from pupil_recording_interface.externals.file_methods import load_object from ved_capture.cli.utils import init_logger, raise_error @click.command("export") +@click.argument("file-type") @click.argument("folder") @click.argument("topics", nargs=-1) -@click.option( - "-t", - "--file-type", - default="pldata", - help="File type to export (pldata, intrinsics, extrinsics).", -) @click.option( "-f", "--format", default="auto", help="Export format.", ) @@ -27,11 +22,18 @@ def export(folder, topics, file_type, format, verbose): """ Export recording data. + This tool will export all topics that match 'file-type', e.g., + 'vedc export intrinsics' will export all '.intrinsics' files. You can also + specify the topics that you want to export at the end. + + Available file types are 'intrinsics', 'extrinsics' and 'pldata'. + \b Note that supported formats depend on file type: - auto: auto-determine format for file type. - - echo: print export to command line. Not supported for pldata types. + - json: json format. Not supported for pldata types. - nc, netcdf: netCDF4 format. Supported for pldata types. + - echo: print export to command line. Not supported for pldata types. """ logger = init_logger(inspect.stack()[0][3], verbosity=verbose) @@ -45,7 +47,7 @@ def export(folder, topics, file_type, format, verbose): supported_formats = ("netcdf", "nc") supported_topics = ("gaze", "odometry", "accel", "gyro") # TODO pupil else: - supported_formats = ("echo",) + supported_formats = ("json", "echo") supported_topics = None # auto-determine format or check if format is supported for file type @@ -89,14 +91,29 @@ def export(folder, topics, file_type, format, verbose): ) except FileNotFoundError as e: raise_error(str(e), logger) + + elif format.lower() == "json": + export_folder = folder / "exports" + export_folder.mkdir(exist_ok=True) + try: + for topic in topics: + obj = pri.load_object(folder / f"{topic}.{file_type}") + with open( + export_folder / f"{topic}.{file_type}.json", "w" + ) as f: + f.write(json.dumps(obj, indent=4)) + except FileNotFoundError as e: + raise_error(str(e), logger) + elif format.lower() == "echo": try: for topic in topics: - obj = load_object(folder / f"{topic}.{file_type}") + obj = pri.load_object(folder / f"{topic}.{file_type}") filename = f"{topic}.{file_type}" header = f"{filename}\n{'-'*len(filename)}\n" logger.info(f"{header}{pformat(obj)}\n") except FileNotFoundError as e: raise_error(str(e), logger) + else: raise_error(f"Unsupported format: {format}", logger) diff --git a/ved_capture/cli/utils.py b/ved_capture/cli/utils.py index 4d578f2..dc74e8b 100644 --- a/ved_capture/cli/utils.py +++ b/ved_capture/cli/utils.py @@ -198,7 +198,8 @@ def mode_prompt(modes): modes = { idx: mode for idx, mode in enumerate( - pri.VideoDeviceUVC._get_available_modes(uid) + # TODO create a public method for this + pri.VideoDeviceUVC._get_uvc_capture(uid).avaible_modes # [sic] ) } selected_mode = mode_prompt(modes) diff --git a/ved_capture/config.py b/ved_capture/config.py index 63c1f7e..79963a6 100644 --- a/ved_capture/config.py +++ b/ved_capture/config.py @@ -18,6 +18,9 @@ ) import pupil_recording_interface as pri +from ved_capture.process.circle_detector import CircleDetectorVEDB +from ved_capture.process.validation import Validation + APPNAME = "vedc" # maximum width of video windows @@ -315,9 +318,9 @@ def _get_validation_pipeline(self, config, cam_type, name): circle_detector_params = {} config["pipeline"].append( - pri.CircleDetector.Config(**circle_detector_params) + CircleDetectorVEDB.Config(**circle_detector_params) ) - config["pipeline"].append(pri.Validation.Config(save=True)) + config["pipeline"].append(Validation.Config(save=True)) config["pipeline"].append(pri.GazeMapper.Config()) config["pipeline"].append( pri.VideoDisplay.Config(max_width=MAX_WIDTH) @@ -387,6 +390,7 @@ def _get_cam_param_pipeline( if "pipeline" not in config: config["pipeline"] = [] + # circle grid detector try: if self.legacy: detector_params = self.get_command_config( @@ -405,13 +409,27 @@ def _get_cam_param_pipeline( config["pipeline"].append( pri.CircleGridDetector.Config(**detector_params) ) + + # first stream gets cam param estimator if master: - # first stream gets cam param estimator - config["pipeline"].append( - pri.CamParamEstimator.Config( - streams=streams, extrinsics=extrinsics + try: + estimator_params = self.get_command_config( + "estimate_cam_params", + "settings", + name, + "cam_param_estimator", ) + except NotFoundError: + estimator_params = {} + + estimator_params.update( + {"streams": streams, "extrinsics": extrinsics} ) + config["pipeline"].append( + pri.CamParamEstimator.Config(**estimator_params) + ) + + # video display config["pipeline"].append(pri.VideoDisplay.Config(max_width=MAX_WIDTH)) return config diff --git a/ved_capture/config_default.yaml b/ved_capture/config_default.yaml index e0fccbc..55cbee9 100644 --- a/ved_capture/config_default.yaml +++ b/ved_capture/config_default.yaml @@ -94,10 +94,12 @@ commands: subject_id: null age: null gender: null - height: null + height_inches: null ethnicity: null - IPD: null + IPD_cm: null tilt_angle: null + birth_year: null + rig_version: null # Selector to auto-determine a profile at the start of the recording. # The selector has to be defined in the metadata and will use the value @@ -168,7 +170,6 @@ commands: circle_detector: scale: 0.5 paused: true - detection_method: vedb marker_size: - 5 - 300 @@ -193,9 +194,15 @@ commands: circle_grid_detector: stereo: true scale: 0.75 + cam_param_estimator: + num_patterns: 15 + grid_scale: 0.02 world: circle_grid_detector: scale: 0.5 + cam_param_estimator: + num_patterns: 15 + grid_scale: 0.02 ## STREAMS # The streams section defines all streams that are used by the CLI commands. diff --git a/ved_capture/process/__init__.py b/ved_capture/process/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ved_capture/process/circle_detector.py b/ved_capture/process/circle_detector.py new file mode 100644 index 0000000..721a85d --- /dev/null +++ b/ved_capture/process/circle_detector.py @@ -0,0 +1,348 @@ +"""""" +import numpy as np +from numpy import linalg as LA +import cv2 + +from pupil_recording_interface.decorators import process +from pupil_recording_interface.process.circle_detector import CircleDetector +from pupil_recording_interface.externals.methods import normalize + + +@process("circle_detector_vedb") +class CircleDetectorVEDB(CircleDetector): + """ Detector for circular calibration markers. + + This process detects the circular calibration marker used for calibrating + the gaze mapper. Attach this process to the world camera stream. + """ + + def __init__( + self, + scale=0.5, + marker_size=(12, 300), + threshold_window_size=13, + min_area=500, + max_area=1000, + circularity=0.8, + convexity=0.7, + inertia=0.4, + display=True, + **kwargs, + ): + """ Constructor. """ + super().__init__(scale=scale, display=display, **kwargs) + + self.circle_tracker = CircleTrackerVEDB( + scale=scale, + marker_size=marker_size, + threshold_window_size=threshold_window_size, + min_area=min_area, + max_area=max_area, + circularity=circularity, + convexity=convexity, + inertia=inertia, + ) + + +class CircleTrackerVEDB: + def __init__( + self, + wait_interval=30, + roi_wait_interval=120, + scale=0.5, + marker_size=(12, 300), + threshold_window_size=13, + min_area=500, + max_area=1000, + circularity=0.8, + convexity=0.7, + inertia=0.4, + ): + self.wait_interval = wait_interval + self.roi_wait_interval = roi_wait_interval + self._previous_markers = [] + self._predict_motion = [] + self._wait_count = 0 + self._roi_wait_count = 0 + self._flag_check = False + self._flag_check_roi = False + self._world_size = None + self.scale = scale + self._marker_size = marker_size + self.threshold_window_size = threshold_window_size + self.min_area = min_area + self.max_area = max_area + self.circularity = circularity + self.convexity = convexity + self.inertia = inertia + + def update(self, img): + """ + Decide whether to track the marker in the roi or in the whole frame + Return all detected markers + + :param img: input gray image + :type img: numpy.ndarray + :return: all detected markers including the information about their + ellipses, center positions and their type + (Ref/Stop) + :rtype: a list containing dictionary with keys: 'ellipses', 'img_pos', + 'norm_pos', 'marker_type' + """ + img_size = img.shape[::-1] + if self._world_size is None: + self._world_size = img_size + elif self._world_size != img_size: + self._previous_markers = [] + self._predict_motion = [] + self._wait_count = 0 + self._roi_wait_count = 0 + self._world_size = img_size + + if self._wait_count <= 0 or self._roi_wait_count <= 0: + self._flag_check = True + self._flag_check_roi = False + self._wait_count = self.wait_interval + self._roi_wait_count = self.roi_wait_interval + + markers = [] + if self._flag_check: + markers = self._check_frame(img) + predict_motion = [] + if len(markers) > 0: + if len(self._previous_markers) in (0, len(markers)): + self._flag_check = True + self._roi_wait_count -= 1 + for i in range(len(self._previous_markers)): + predict_motion.append( + np.array(markers[i]["img_pos"]) + - np.array(self._previous_markers[i]["img_pos"]) + ) + else: + if self._flag_check_roi: + self._flag_check = True + self._flag_check_roi = False + else: + self._flag_check = False + self._flag_check_roi = False + + self._wait_count -= 1 + self._previous_markers = markers + return markers + + def _check_frame(self, img): + """ + Track the markers in the ROIs / in the whole frame + + :param img: input gray image + :type img: numpy.ndarray + :return: all detected markers including the information about their + ellipses, center positions and their type (Ref/Stop) + :rtype: a list containing dictionary with keys: 'ellipses', 'img_pos', + 'norm_pos', 'marker_type' + """ + img_size = img.shape[::-1] + marker_list = [] + + # Check whole frame + if not self._flag_check_roi: + ellipses_list = self.find_vedb_circle_marker( + img, self.scale, self._marker_size + ) + + # Save the markers in dictionaries + for ellipses_ in ellipses_list: + ellipses = ellipses_["ellipses"] + img_pos = ellipses[0][0] + norm_pos = normalize(img_pos, img_size, flip_y=True) + marker_list.append( + { + "ellipses": ellipses, + "img_pos": img_pos, + "norm_pos": norm_pos, + "marker_type": ellipses_["marker_type"], + } + ) + + # Check roi + else: + for i in range(len(self._previous_markers)): + largest_ellipse = self._previous_markers[i]["ellipses"][-1] + + # Set up the boundary of the roi + if self._predict_motion: + predict_center = ( + largest_ellipse[0][0] + self._predict_motion[i][0], + largest_ellipse[0][1] + self._predict_motion[i][1], + ) + b0 = ( + predict_center[0] + - largest_ellipse[1][1] + - abs(self._predict_motion[i][0]) * 2 + ) + b1 = ( + predict_center[0] + + largest_ellipse[1][1] + + abs(self._predict_motion[i][0]) * 2 + ) + b2 = ( + predict_center[1] + - largest_ellipse[1][0] + - abs(self._predict_motion[i][1]) * 2 + ) + b3 = ( + predict_center[1] + + largest_ellipse[1][0] + + abs(self._predict_motion[i][1]) * 2 + ) + else: + predict_center = largest_ellipse[0] + b0 = predict_center[0] - largest_ellipse[1][1] + b1 = predict_center[0] + largest_ellipse[1][1] + b2 = predict_center[1] - largest_ellipse[1][0] + b3 = predict_center[1] + largest_ellipse[1][0] + + b0 = 0 if b0 < 0 else int(b0) + b1 = img_size[0] - 1 if b1 > img_size[0] - 1 else int(b1) + b2 = 0 if b2 < 0 else int(b2) + b3 = img_size[1] - 1 if b3 > img_size[1] - 1 else int(b3) + col_slice = b0, b1 + row_slice = b2, b3 + + ellipses_list = self.find_vedb_circle_marker( + img[slice(*row_slice), slice(*col_slice)], + self.scale, + self._marker_size, + ) + + # Track the marker which was detected last frame; + # To avoid more than one markers are detected in one ROI + if len(ellipses_list): + if len(ellipses_list) == 1: + right_ellipses = ellipses_list[0] + else: + pre_pos = np.array( + ( + self._previous_markers[i]["img_pos"][0] - b0, + self._previous_markers[i]["img_pos"][1] - b2, + ) + ) + temp_dist = [ + LA.norm(e["ellipses"][0][0] - pre_pos) + for e in ellipses_list + ] + right_ellipses = ellipses_list[ + temp_dist.index(min(temp_dist)) + ] + ellipses = [ + ((e[0][0] + b0, e[0][1] + b2), e[1], e[2]) + for e in right_ellipses["ellipses"] + ] + img_pos = ellipses[0][0] + norm_pos = normalize(img_pos, img_size, flip_y=True) + # Save the marker in dictionary + marker_list.append( + { + "ellipses": ellipses, + "img_pos": img_pos, + "norm_pos": norm_pos, + "marker_type": right_ellipses["marker_type"], + } + ) + + return marker_list + + def threshold_frame(self, frame, window_size): + return cv2.adaptiveThreshold( + frame, + 255, + cv2.ADAPTIVE_THRESH_MEAN_C, + cv2.THRESH_BINARY, + window_size, + 2, + ) + + def erode_frame(self, frame, window_size): + kernel = np.ones((window_size, window_size), int) + return cv2.erode(frame, kernel, iterations=1) + + def define_blob_detector(self): + # Todo: Make sure these parameters are passed through constructor + # arguments + + # Set our filtering parameters + # Initialize parameter settiing using cv2.SimpleBlobDetector + params = cv2.SimpleBlobDetector_Params() + + # Set Area filtering parameters + params.filterByArea = True + params.minArea = self.min_area + params.maxArea = self.max_area + + # Set Circularity filtering parameters + params.filterByCircularity = True + params.minCircularity = self.circularity + # params.minCircularity = 0.7 + # Set Convexity filtering parameters + params.filterByConvexity = True + params.minConvexity = self.convexity + # params.minConvexity = 0.7 + # Set inertia filtering parameters + params.filterByInertia = True + params.minInertiaRatio = self.inertia + # params.minInertiaRatio = 0.6 + + # Create a detector with the parameters + return cv2.SimpleBlobDetector_create(params) + + def find_vedb_circle_marker(self, frame, scale, marker_size): + + # Resize the image + image = cv2.resize(frame, None, fx=scale, fy=scale) + ellipses_list = [] + + # Here we set up our opencv blob detecter code + detector = self.define_blob_detector() + + # Perform image thresholding using an adaptive threshold window + window_size = self.threshold_window_size + image = self.threshold_frame(image, window_size) + + # Perform image erosion in order to remove the possible bright points + # inside the marker + window_size = 3 + image = self.erode_frame(image, window_size) + + # Detect blobs using opencv blob detector that we setup earlier in the + # code + keypoints = detector.detect(image) + + # Check if there is any blobs detected or not, if yes then draw it + # using a red color + if len(keypoints) > 0: + + for keypoint in keypoints: + # Todo: Define acceptable range through constructor argument + if ( + marker_size[0] < keypoint.size < marker_size[1] + ): # 15 and 42 + # Todo: Make sure the fields in ellipse are the same as in + # pupil code + # Todo: Make sure whether the opencv y axis needs to be + # negated!! + ellipses_list.append( + { + "ellipses": [ + ( + ( + keypoint.pt[0] * (1 / scale), + keypoint.pt[1] * (1 / scale), + ), + (keypoint.size, keypoint.size), + keypoint.angle, + ), + ], + "marker_type": "Ref", + } + ) + return ellipses_list diff --git a/ved_capture/process/validation.py b/ved_capture/process/validation.py new file mode 100644 index 0000000..e9b41c7 --- /dev/null +++ b/ved_capture/process/validation.py @@ -0,0 +1,126 @@ +"""""" +import logging + +from pupil_recording_interface.decorators import process +from pupil_recording_interface.process.calibration import Calibration +import numpy as np + +logger = logging.getLogger(__name__) + + +# TODO once pri 1.0 is released: +# @process("validation", optional=("resolution",)) +class Validation(Calibration): + """ Validation during runtime class. """ + + def __init__( + self, + resolution, + eye_resolution=None, + mode="2d", + min_confidence=0.8, + left="eye1", + right="eye0", + world="world", + name=None, + folder=None, + save=False, + **kwargs, + ): + """ Constructor. """ + super().__init__( + resolution, + mode=mode, + min_confidence=min_confidence, + left=left, + right=right, + world=world, + name=name, + folder=folder, + save=save, + **kwargs, + ) + self.eye_resolution = eye_resolution + + def plot_markers(self, circle_marker_list, filename): + """ Plot marker coverage. """ + import matplotlib.pyplot as plt + + plt.figure(figsize=(8, 8)) + x = [c["img_pos"][0] for c in circle_marker_list] + y = [c["img_pos"][1] for c in circle_marker_list] + # Note that: y axis in opencv is inverse of matplotlib! + logger.debug(f"plotting {len(x)} marker points") + plt.plot( + x, self.resolution[1] - np.array(y), "or", markersize=10, alpha=0.7 + ) + plt.xlim(0, self.resolution[0]) + plt.ylim(0, self.resolution[1]) + plt.grid(True) + plt.title("Marker Position", fontsize=18) + plt.rc("xtick", labelsize=12) + plt.rc("ytick", labelsize=12) + plt.xlabel("X (pixels)", fontsize=14) + plt.ylabel("Y (pixels)", fontsize=14) + # Todo: Check if we can pass a flag to show the plots + # (currently doesn't work with the thread timers) + # plt.show() + if filename is not None: + figure_file_name = filename.parent / "marker_coverage.png" + plt.savefig(figure_file_name, dpi=200) + logger.info(f"saved marker plot at: {figure_file_name}") + plt.close() + + def plot_pupils(self, pupil_list, filename): + """ Plot pupil coverage. """ + import matplotlib.pyplot as plt + + if self.eye_resolution is None: + res = (1.0, 1.0) + else: + res = self.eye_resolution + plt.figure(figsize=(8, 8)) + x = [p["norm_pos"][0] * res[0] for p in pupil_list if p["id"] == 0] + y = [p["norm_pos"][1] * res[1] for p in pupil_list if p["id"] == 0] + logger.debug(f"plotting {len(x)} right pupil points") + plt.plot(x, y, "*y", markersize=10, alpha=0.7, label="right") + + x = [p["norm_pos"][0] * res[0] for p in pupil_list if p["id"] == 1] + y = [p["norm_pos"][1] * res[1] for p in pupil_list if p["id"] == 1] + logger.debug(f"plotting {len(x)} left pupil points") + plt.plot(x, y, "*g", markersize=10, alpha=0.7, label="left") + + plt.xlim(0, res[0]) + plt.ylim(0, res[1]) + plt.grid(True) + plt.title("Pupil Position", fontsize=18) + plt.rc("xtick", labelsize=12) + plt.rc("ytick", labelsize=12) + if self.eye_resolution is not None: + plt.xlabel("X (pixels)", fontsize=14) + plt.ylabel("Y (pixels)", fontsize=14) + else: + plt.xlabel("X (normalized position)", fontsize=14) + plt.ylabel("Y (normalized position)", fontsize=14) + if filename is not None: + figure_file_name = filename.parent / "pupil_coverage.png" + plt.savefig(figure_file_name, dpi=200) + logger.info(f"saved pupil plot at: {figure_file_name}") + # Todo: Check if we can pass a flag to show the plots + # (currently doesn't work with the thread timers) + # plt.show() + plt.close() + + def calculate_calibration(self): + """ Calculate calibration from collected data. """ + ( + circle_marker_list, + pupil_list, + filename, + ) = super().calculate_calibration() + logger.info("Plotting Coverage for marker and pupil...") + self.plot_markers(circle_marker_list, filename) + self.plot_pupils(pupil_list, filename) + + logger.info("Validation and Calibration Done Successfully!") + return circle_marker_list, pupil_list, filename diff --git a/ved_capture/utils.py b/ved_capture/utils.py index 7b60932..315727a 100644 --- a/ved_capture/utils.py +++ b/ved_capture/utils.py @@ -1,249 +1,24 @@ """""" -import os -import json import logging -import re import shutil -import subprocess import time from pathlib import Path -from select import select import numpy as np import simpleaudio from confuse import NotFoundError from simpleaudio._simpleaudio import SimpleaudioError -from pkg_resources import parse_version -import git import pupil_recording_interface as pri from pupil_recording_interface.externals.file_methods import load_object -from ved_capture.config import ConfigParser logger = logging.getLogger(__name__) -def log_as_warning_or_debug(data): - """ Log message as warning, unless it's known to be a debug message. """ - _suppress_if_startswith = ( - "[sudo] ", - 'Please run using "bash" or "sh"', - "==> WARNING: A newer version of conda exists. <==", - ) - - _suppress_if_endswith = ("is not a symbolic link",) - - _suppress_if_contains = ("Extracting : ",) - - try: - data = data.strip(b"\n").decode("utf-8") - except UnicodeDecodeError: - logger.debug("!!Error decoding process output!!") - return - - if ( - data.startswith(_suppress_if_startswith) - or data.endswith(_suppress_if_endswith) - or any(data.find(s) for s in _suppress_if_contains) - ): - logger.debug(data) - else: - logger.warning(data) - - -def log_as_debug(data): - """ Log message as debug. """ - try: - data = data.rstrip(b"\n").decode("utf-8") - logger.debug(data) - except UnicodeDecodeError: - logger.debug("!!Error decoding process output!!") - - -def run_command(command, shell=False, f_stdout=None, n_bytes=4096): - """ Run system command and pipe output to logger. """ - with subprocess.Popen( - command, - stdout=f_stdout or subprocess.PIPE, - stderr=subprocess.PIPE, - shell=shell, - ) as process: - if f_stdout is None: - readable = { - process.stdout.fileno(): log_as_debug, - process.stderr.fileno(): log_as_warning_or_debug, - } - else: - readable = { - process.stderr.fileno(): log_as_warning_or_debug, - } - - while readable: - for fd in select(readable, [], [])[0]: - data = os.read(fd, n_bytes) # read available - if not data: # EOF - del readable[fd] - else: - readable[fd](data) - - return process.wait() - - -def get_paths(config_dir=None): - """ Get dictionary with application paths. """ - config_dir = Path(config_dir or ConfigParser.config_dir()) - try: - with open(config_dir / "paths.json") as f: - return json.loads(f.read()) - except FileNotFoundError: - return None - - -def write_paths(paths, config_dir=None): - """ Write dictionary with application paths. """ - config_dir = Path(config_dir or ConfigParser.config_dir()) - with open(config_dir / "paths.json", "w") as f: - f.write(json.dumps(paths)) - - -def update_repo(repo_folder, branch=None, stash=False): - """ Update repository. """ - repo = git.Repo(repo_folder) - current_hash = repo.head.object.hexsha - - # fetch updates - repo.remotes.origin.fetch() - if stash: - repo.git.stash() - - # get latest version or specified branch - branch = ( - branch or sorted(repo.tags, key=lambda t: parse_version(t.name))[-1] - ) - repo.git.checkout(branch) - logger.info(f"Checked out {branch}") - if not repo.head.is_detached: - repo.git.merge() - - # Return True if the repo was updated - return current_hash != repo.head.object.hexsha - - -def get_min_conda_devenv_version(devenv_file): - """ Get minimum conda devenv version. """ - with open(devenv_file) as f: - for line in f: - pattern = re.compile(r'{{ min_conda_devenv_version\("(.+)"\) }}') - result = re.search(pattern, line) - if result: - return result.group(1) - else: - return "2.1.1" - - -def update_environment( - paths, - devenv_file="environment.devenv.yml", - local=False, - pri_branch=None, - pri_path=None, -): - """ Update conda environment. """ - conda_prefix = Path(paths["conda_binary"]).parents[1] - devenv_file = Path(paths["vedc_repo_folder"]) / devenv_file - env_file = Path(paths["vedc_repo_folder"]) / "environment.yml" - if not devenv_file.exists(): - devenv_file = env_file - - if local: - os.environ["VEDC_DEV"] = "" - if pri_branch: - os.environ["PRI_PIN"] = pri_branch - if pri_path: - paths["pri_path"] = str(Path(pri_path).expanduser().resolve()) - write_paths(paths) - - # Only install PRI from local repo if pri_branch isn't set but local or - # pri_path is - if not pri_branch and "pri_path" in paths and (local or pri_path): - os.environ["PRI_PATH"] = paths["pri_path"] - - # TODO hacky way of installing to base environment - os.environ["CONDA_PREFIX"] = str(conda_prefix) - - # Install mamba if missing - if "mamba_binary" not in paths or not Path(paths["mamba_binary"]).exists(): - logger.info("Installing mamba. 🐍") - run_command( - [ - paths["conda_binary"], - "install", - "-y", - "-c", - "conda-forge", - "-n", - "base", - "mamba", - ] - ) - paths["mamba_binary"] = str(conda_prefix / "condabin" / "mamba") - write_paths(paths) - - # Update conda devenv - return_code = run_command( - [ - paths["mamba_binary"], - "install", - "-y", - "-c", - "conda-forge", - "-n", - "base", - f"conda-devenv>={get_min_conda_devenv_version(devenv_file)}", - ] - ) - if return_code != 0: - return return_code - - if ( - "VEDCDIR" in os.environ - and Path(os.environ["VEDCDIR"]) != Path("~/.config/vedc").expanduser() - ): - # Can't use mamba yet if config folder is not in default location - return run_command( - [paths["conda_binary"], "devenv", "-f", devenv_file] - ) - else: - # Update environment.yml with conda devenv - try: - env_file.unlink() - except FileNotFoundError: - pass - with open(env_file, "w") as f: - return_code = run_command( - [ - paths["conda_binary"], - "devenv", - "-f", - devenv_file, - "--print", - ], - f_stdout=f, - ) - if return_code != 0: - return return_code - - # Update environment with mamba - return_code = run_command( - [paths["mamba_binary"], "env", "update", "-f", str(env_file)] - ) - - return return_code - - def get_pupil_devices(): """ Get names and UIDs of connected Pupil cameras. """ + # TODO create a public member for this connected_devices = pri.VideoDeviceUVC._get_connected_device_uids() pupil_cams = { name: uid @@ -462,3 +237,239 @@ def check_disk_space(folder, min_space_gb=30): f"Available disk space in {folder} is {free_gb:.1f} GB, make sure " f"you have at least {min_space_gb} GB of free space" ) + + +### DEPRECATED ### +def log_as_warning_or_debug(data): + """ Log message as warning, unless it's known to be a debug message. """ + _suppress_if_startswith = ( + "[sudo] ", + 'Please run using "bash" or "sh"', + "==> WARNING: A newer version of conda exists. <==", + ) + + _suppress_if_endswith = ("is not a symbolic link",) + + _suppress_if_contains = ("Extracting : ",) + + try: + data = data.strip(b"\n").decode("utf-8") + except UnicodeDecodeError: + logger.debug("!!Error decoding process output!!") + return + + if ( + data.startswith(_suppress_if_startswith) + or data.endswith(_suppress_if_endswith) + or any(data.find(s) for s in _suppress_if_contains) + ): + logger.debug(data) + else: + logger.warning(data) + + +def log_as_debug(data): + """ Log message as debug. """ + try: + data = data.rstrip(b"\n").decode("utf-8") + logger.debug(data) + except UnicodeDecodeError: + logger.debug("!!Error decoding process output!!") + + +def run_command(command, shell=False, f_stdout=None, n_bytes=4096): + """ Run system command and pipe output to logger. """ + import os + import subprocess + from select import select + + with subprocess.Popen( + command, + stdout=f_stdout or subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell, + ) as process: + if f_stdout is None: + readable = { + process.stdout.fileno(): log_as_debug, + process.stderr.fileno(): log_as_warning_or_debug, + } + else: + readable = { + process.stderr.fileno(): log_as_warning_or_debug, + } + + while readable: + for fd in select(readable, [], [])[0]: + data = os.read(fd, n_bytes) # read available + if not data: # EOF + del readable[fd] + else: + readable[fd](data) + + return process.wait() + + +def get_paths(config_dir=None): + """ Get dictionary with application paths. """ + import json + from ved_capture.config import ConfigParser + + config_dir = Path(config_dir or ConfigParser.config_dir()) + try: + with open(config_dir / "paths.json") as f: + return json.loads(f.read()) + except FileNotFoundError: + return None + + +def write_paths(paths, config_dir=None): + """ Write dictionary with application paths. """ + import json + from ved_capture.config import ConfigParser + + config_dir = Path(config_dir or ConfigParser.config_dir()) + with open(config_dir / "paths.json", "w") as f: + f.write(json.dumps(paths)) + + +def update_repo(repo_folder, branch=None, stash=False): + """ Update repository. """ + import git + from pkg_resources import parse_version + + repo = git.Repo(repo_folder) + current_hash = repo.head.object.hexsha + + # fetch updates + repo.remotes.origin.fetch() + if stash: + repo.git.stash() + + # get latest version or specified branch + branch = ( + branch or sorted(repo.tags, key=lambda t: parse_version(t.name))[-1] + ) + repo.git.checkout(branch) + logger.info(f"Checked out {branch}") + if not repo.head.is_detached: + repo.git.merge() + + # Return True if the repo was updated + return current_hash != repo.head.object.hexsha + + +def get_min_conda_devenv_version(devenv_file): + """ Get minimum conda devenv version. """ + import re + + with open(devenv_file) as f: + for line in f: + pattern = re.compile(r'{{ min_conda_devenv_version\("(.+)"\) }}') + result = re.search(pattern, line) + if result: + return result.group(1) + else: + return "2.1.1" + + +def update_environment( + paths, + devenv_file="environment.devenv.yml", + local=False, + pri_branch=None, + pri_path=None, +): + """ Update conda environment. """ + import os + + conda_prefix = Path(paths["conda_binary"]).parents[1] + devenv_file = Path(paths["vedc_repo_folder"]) / devenv_file + env_file = Path(paths["vedc_repo_folder"]) / "environment.yml" + if not devenv_file.exists(): + devenv_file = env_file + + if local: + os.environ["VEDC_DEV"] = "" + if pri_branch: + os.environ["PRI_PIN"] = pri_branch + if pri_path: + paths["pri_path"] = str(Path(pri_path).expanduser().resolve()) + write_paths(paths) + + # Only install PRI from local repo if pri_branch isn't set but local or + # pri_path is + if not pri_branch and "pri_path" in paths and (local or pri_path): + os.environ["PRI_PATH"] = paths["pri_path"] + + # TODO hacky way of installing to base environment + os.environ["CONDA_PREFIX"] = str(conda_prefix) + + # Install mamba if missing + if "mamba_binary" not in paths or not Path(paths["mamba_binary"]).exists(): + logger.info("Installing mamba. 🐍") + run_command( + [ + paths["conda_binary"], + "install", + "-y", + "-c", + "conda-forge", + "-n", + "base", + "mamba", + ] + ) + paths["mamba_binary"] = str(conda_prefix / "condabin" / "mamba") + write_paths(paths) + + # Update conda devenv + return_code = run_command( + [ + paths["mamba_binary"], + "install", + "-y", + "-c", + "conda-forge", + "-n", + "base", + f"conda-devenv>={get_min_conda_devenv_version(devenv_file)}", + ] + ) + if return_code != 0: + return return_code + + if ( + "VEDCDIR" in os.environ + and Path(os.environ["VEDCDIR"]) != Path("~/.config/vedc").expanduser() + ): + # Can't use mamba yet if config folder is not in default location + return run_command( + [paths["conda_binary"], "devenv", "-f", devenv_file] + ) + else: + # Update environment.yml with conda devenv + try: + env_file.unlink() + except FileNotFoundError: + pass + with open(env_file, "w") as f: + return_code = run_command( + [ + paths["conda_binary"], + "devenv", + "-f", + devenv_file, + "--print", + ], + f_stdout=f, + ) + if return_code != 0: + return return_code + + # Update environment with mamba + return_code = run_command( + [paths["mamba_binary"], "env", "update", "-f", str(env_file)] + ) + + return return_code