Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions camel_tools/disambig/bert/unfactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import json
from pathlib import Path
import pickle
import warnings

from cachetools import LFUCache
import numpy as np
Expand Down Expand Up @@ -289,7 +290,7 @@ def __init__(self, model_path, analyzer,

@staticmethod
def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
cache_size=10000, pretrained_cache=True,
cache_size=10000, pretrained_cache=False,
ranking_cache_size=100000):
"""Load a pre-trained model provided with camel_tools.

Expand All @@ -306,9 +307,11 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
the analyzer will cache the analyses for the cache_size most
frequent words, otherwise no analyses will be cached.
Defaults to 100000.
pretrained_cache (:obj:`bool`, optional): The flag to use a
pretrained cache that stores ranked analyses.
Defaults to True.
pretrained_cache (:obj:`bool`, optional): Deprecated. The flag
used to load a pretrained cache that stores ranked analyses.
We are removing the pretrained cache for now but are keeping
the option here to not break existing code.
Defaults to False.
ranking_cache_size (:obj:`int`, optional): The number of unique
word disambiguations to cache. If 0, no ranked analyses will be
cached. The cache uses a least-frequently-used eviction policy.
Expand All @@ -330,14 +333,13 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,
cache_size=cache_size)
scorer = model_config['scorer']
tie_breaker = model_config['tie_breaker']
ranking_cache = None
if pretrained_cache:
cache_info = CATALOGUE.get_dataset('DisambigRankingCache',
model_config['ranking_cache'])
cache_path = Path(cache_info.path, 'default_cache.pickle')
with open(cache_path, 'rb') as f:
ranking_cache = pickle.load(f)
else:
ranking_cache = None
warnings.warn(
'The `pretrained_cache` argument is deprecated and will be '
'removed in a future release.',
DeprecationWarning,
stacklevel=2)

return BERTUnfactoredDisambiguator(
model_path,
Expand All @@ -353,7 +355,7 @@ def pretrained(model_name='msa', top=1, use_gpu=True, batch_size=32,

@staticmethod
def _pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
cache_size=10000, pretrained_cache=True,
cache_size=10000, pretrained_cache=False,
ranking_cache_size=100000):
"""Load a pre-trained model from a config file.

Expand All @@ -369,9 +371,11 @@ def _pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
the analyzer will cache the analyses for the cache_size
most frequent words, otherwise no analyses will be cached.
Defaults to 100000.
pretrained_cache (:obj:`bool`, optional): The flag to use a
pretrained cache that stores ranked analyses.
Defaults to True.
pretrained_cache (:obj:`bool`, optional): Deprecated. The flag
used to load a pretrained cache that stores ranked analyses.
We are removing the pretrained cache for now but are keeping
the option here to not break existing code.
Defaults to False.
ranking_cache_size (:obj:`int`, optional): The number of unique
word disambiguations to cache. If 0, no ranked analyses will be
cached. The cache uses a least-frequently-used eviction policy.
Expand All @@ -392,12 +396,13 @@ def _pretrained_from_config(config, top=1, use_gpu=True, batch_size=32,
cache_size=cache_size)
scorer = model_config['scorer']
tie_breaker = model_config['tie_breaker']
ranking_cache = None
if pretrained_cache:
cache_path = model_config['ranking_cache']
with open(cache_path, 'rb') as f:
ranking_cache = pickle.load(f)
else:
ranking_cache = None
warnings.warn(
'The `pretrained_cache` argument is deprecated and will be '
'removed in a future release.',
DeprecationWarning,
stacklevel=2)

return BERTUnfactoredDisambiguator(
model_path,
Expand Down