diff --git a/gget/__init__.py b/gget/__init__.py
index 3c65663f..e838b033 100644
--- a/gget/__init__.py
+++ b/gget/__init__.py
@@ -19,6 +19,7 @@
 from .gget_opentargets import opentargets
 from .gget_cbio import cbio_plot, cbio_search
 from .gget_bgee import bgee
+from .gget_dataverse import dataverse
 from .gget_8cube import specificity, psi_block, gene_expression
 from .gget_virus import virus
 
diff --git a/gget/constants.py b/gget/constants.py
index 7bf1bdf7..0d125ca0 100644
--- a/gget/constants.py
+++ b/gget/constants.py
@@ -72,6 +72,9 @@
 # OpenTargets API endpoint
 OPENTARGETS_GRAPHQL_API = "https://api.platform.opentargets.org/api/v4/graphql"
 
+# Harvard dataverse API server
+DATAVERSE_GET_URL = "https://dataverse.harvard.edu/api/access/datafile/"
+
 # CBIO data
 CBIO_CANCER_TYPE_TO_TISSUE_DICTIONARY = {
     "Acute Leukemias of Ambiguous Lineage": "leukemia",
diff --git a/gget/gget_dataverse.py b/gget/gget_dataverse.py
new file mode 100644
index 00000000..7d63f58e
--- /dev/null
+++ b/gget/gget_dataverse.py
@@ -0,0 +1,114 @@
+import os
+import requests
+from tqdm import tqdm
+import pandas as pd
+from .utils import print_sys
+from .constants import DATAVERSE_GET_URL
+
+
+def dataverse_downloader(url, path, file_name):
+    """Stream one datafile from `url` into `path` while showing a progress bar.
+
+    Args:
+        url (str): the url of the datafile to download
+        path (str): the existing directory to save the datafile in
+        file_name (str): the name of the file to save locally
+
+    Raises:
+        requests.HTTPError: if the server responds with an error status code
+    """
+    save_path = os.path.join(path, file_name)
+    # Write to a temporary name first: download_wrapper treats any existing
+    # file as a finished copy, so a partial download must never get that name.
+    tmp_path = save_path + ".part"
+    block_size = 1024
+    with requests.get(url, stream=True) as response:
+        # Fail loudly instead of saving an error page as if it were the data
+        response.raise_for_status()
+        total_size_in_bytes = int(response.headers.get("content-length", 0))
+        with open(tmp_path, "wb") as file, tqdm(
+            total=total_size_in_bytes, unit="iB", unit_scale=True
+        ) as progress_bar:
+            for data in response.iter_content(block_size):
+                progress_bar.update(len(data))
+                file.write(data)
+    os.replace(tmp_path, save_path)
+
+
+def download_wrapper(entry, path, return_type=None):
+    """Download the datafile described by `entry` unless a local copy exists.
+
+    Args:
+        entry (dict): the entry of the dataset to download. Must include 'id', 'name', 'type' keys
+        path (str): the path to save the dataset locally (created if missing)
+        return_type (str or list, optional): what to return. Defaults to None.
+            Can be "url", "filename", or ["url", "filename"]
+
+    Returns:
+        None, str or tuple: the download url, the local file name, or both,
+        depending on `return_type`; None for any other `return_type`
+    """
+    url = DATAVERSE_GET_URL + str(entry["id"])
+
+    # Unlike os.mkdir, this also creates missing parent directories and does
+    # not fail when the directory already exists
+    os.makedirs(path, exist_ok=True)
+
+    filename = f"{entry['name']}.{entry['type']}"
+
+    if os.path.exists(os.path.join(path, filename)):
+        print_sys(f"Found local copy for {entry['id']} datafile as {filename} ...")
+    else:
+        print_sys(f"Downloading {entry['id']} datafile as {filename} ...")
+        dataverse_downloader(url, path, filename)
+
+    if return_type == "url":
+        return url
+    if return_type == "filename":
+        return filename
+    # Accept a tuple as well as a list for the two-value form
+    if isinstance(return_type, (list, tuple)) and list(return_type) == ["url", "filename"]:
+        return url, filename
+    return None
+
+
+def dataverse(df, path, sep=","):
+    """Download datasets from dataverse for a given dataframe.
+
+    Input dataframe must have 'name', 'id', 'type' columns.
+    - 'name' is the dataset name for single file
+    - 'id' is the unique identifier for the file
+    - 'type' is the file type (e.g. csv, tsv, pkl)
+
+    Args:
+        df (pd.DataFrame or str): the dataframe or path to the csv/tsv file
+        path (str): the path to save the dataset locally
+        sep (str, optional): column separator used when `df` is a file path. Defaults to ",".
+
+    Raises:
+        FileNotFoundError: if `df` is a path that does not exist
+        ValueError: if `df` is neither a dataframe nor a string, or if a
+            required column is missing
+    """
+    if isinstance(df, str):
+        if not os.path.exists(df):
+            raise FileNotFoundError(f"File {df} not found")
+        df = pd.read_csv(df, sep=sep)
+    elif not isinstance(df, pd.DataFrame):
+        raise ValueError("Input must be a pandas dataframe or a path to a csv / tsv file")
+
+    # Check the schema up front so a wrong separator or a mislabeled column is
+    # reported clearly instead of surfacing as a KeyError mid-download
+    missing_columns = {"id", "name", "type"} - set(df.columns)
+    if missing_columns:
+        raise ValueError(
+            f"Input table is missing required column(s): {sorted(missing_columns)}"
+        )
+
+    print_sys(f"Searching for {len(df)} datafiles in dataverse ...")
+
+    # run the download wrapper for each entry in the dataframe
+    for _, entry in df.iterrows():
+        download_wrapper(entry, path)
+
+    print_sys(f"Download completed, saved to `{path}`.")
diff --git a/gget/gget_setup.py b/gget/gget_setup.py
index 7a9f5d0a..e197a44f 100644
--- a/gget/gget_setup.py
+++ b/gget/gget_setup.py
@@ -284,6 +284,7 @@ def setup(module, verbose=True, out=None):
         # Core AlphaFold dependencies (Colab/CPU friendly set)
         alphafold_deps = [
             "absl-py>=2.1,<3",
+            "biopython",
             "dm-haiku<=0.0.12",  # dont upgrade to avoid clash with jax
             "dm-tree>=0.1.8",
             "filelock>=3.12",
diff --git a/gget/main.py b/gget/main.py
index 531ae41a..a1386c63 100644
--- a/gget/main.py
+++ b/gget/main.py
@@ -40,6 +40,7 @@
 from .gget_opentargets import opentargets, OPENTARGETS_RESOURCES
 from .gget_cbio import cbio_plot, cbio_search
 from .gget_bgee import bgee
+from .gget_dataverse import dataverse
 from .gget_8cube import specificity, psi_block, gene_expression
 from .gget_virus import virus
 
@@ -2334,6 +2335,32 @@ def main():
         help="Does not print progress information.",
     )
 
+    ## dataverse parser arguments
+    dataverse_desc = "Download datasets from the Dataverse repositories."
+    parser_dataverse = parent_subparsers.add_parser(
+        "dataverse",
+        parents=[parent],
+        description=dataverse_desc,
+        help=dataverse_desc,
+        add_help=True,
+        formatter_class=CustomHelpFormatter,
+    )
+    parser_dataverse.add_argument(
+        "-o",
+        "--path",
+        type=str,
+        required=True,
+        help="Path to the directory the datasets will be saved in, e.g. 'path/to/directory'.",
+    )
+    parser_dataverse.add_argument(
+        "-t",
+        "--table",
+        type=str,
+        default=None,
+        required=True,
+        help="File containing the dataset IDs to download, e.g. 'datasets.tsv'.",
+    )
+
     ## gget 8cube subparser
     cube_desc = "Query 8cubeDB (https://eightcubedb.onrender.com/)."
     parser_8cube = parent_subparsers.add_parser(
@@ -2834,6 +2861,7 @@ def main():
         "opentargets": parser_opentargets,
         "cbio": parser_cbio,
         "bgee": parser_bgee,
+        "dataverse": parser_dataverse,
         "8cube": parser_8cube,
         "virus": parser_virus,
     }
@@ -3733,6 +3761,21 @@ def main():
                     bgee_results.to_json(orient="records", force_ascii=False, indent=4)
                 )
 
+    ## dataverse return
+    if args.command == "dataverse":
+        # Define separator based on file extension: tab for .tsv/.tab tables;
+        # comma is the fallback so that `sep` is always defined
+        if args.table.lower().endswith((".tsv", ".tab")):
+            sep = "\t"
+        else:
+            sep = ","
+        # Run gget dataverse function ('-o/--path' is stored as args.path)
+        dataverse(
+            df=args.table,
+            path=args.path,
+            sep=sep,
+        )
+
     ## 8cube return
     if args.command == "8cube":
         from .gget_8cube import specificity, psi_block, gene_expression
diff --git a/gget/utils.py b/gget/utils.py
index 41cee916..9800ebb2 100644
--- a/gget/utils.py
+++ b/gget/utils.py
@@ -5,6 +5,7 @@
 # import time
 import re
 import os
+import sys
 import uuid
 import pandas as pd
 import numpy as np
@@ -66,6 +67,15 @@ def flatten(xss):
     return [x for xs in xss for x in xs]
 
 
+def print_sys(s):
+    """Print a status message to stderr and flush it immediately.
+
+    Args:
+        s (str): the string to print
+    """
+    print(s, flush=True, file=sys.stderr)
+
+
 def get_latest_cosmic():
     html = requests.get(COSMIC_RELEASE_URL)
     if html.status_code != 200:
diff --git a/tests/test_dataverse.py b/tests/test_dataverse.py
new file mode 100644
index 00000000..c1916354
--- /dev/null
+++ b/tests/test_dataverse.py
@@ -0,0 +1,42 @@
+import os
+import shutil
+import tempfile
+import unittest
+
+import pandas as pd
+
+from gget.gget_dataverse import dataverse
+
+
+# TODO: Verify the test code, this was drafted using Copilot!
+class TestDataverse(unittest.TestCase):
+    def setUp(self):
+        # Download into a fresh temporary directory that is removed even when
+        # the download or an assertion fails
+        self.out_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self.out_dir, ignore_errors=True)
+
+    def test_dataverse_download(self):
+        df = pd.DataFrame({
+            'id': [6180617],
+            'name': ['nodes'],
+            'type': ['tab']
+        })
+
+        dataverse(df, self.out_dir)
+
+        # Check if the file is downloaded
+        self.assertTrue(os.path.exists(os.path.join(self.out_dir, 'nodes.tab')))
+
+    def test_invalid_input_type(self):
+        # Rejected before any download starts, so no network access is needed
+        with self.assertRaises(ValueError):
+            dataverse(123, self.out_dir)
+
+    def test_missing_table_file(self):
+        with self.assertRaises(FileNotFoundError):
+            dataverse('does_not_exist.csv', self.out_dir)
+
+
+if __name__ == '__main__':
+    unittest.main()