From c9fb2c0327312e33e8724a6a60d7570ccb683c82 Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Tue, 12 May 2026 06:05:28 +0000 Subject: [PATCH 01/10] Add MXFP8/NVFP4 quantization, quantized model init, collator state, and 70B support - Add quantization.py with layer-wise precision (MXFP8/NVFP4/BF16 per layer) - Add quantized_model_init support with FusedAdam FP32 master weights - Add stateful TokenPackingDataset for checkpoint resume across all collators - Add 70B Llama configs with context parallelism and THD format - Add checkpoint.py with FSDP2 save/load utilities - Update modeling_llama_te.py with per-layer FP8 recipe injection - Update train_fsdp2.py and train_fsdp2_cp.py for quantized training - Add comprehensive tests for quantization and checkpoint resume Co-Authored-By: Claude Opus 4.6 --- bionemo-recipes/models/esm2/collator.py | 13 + bionemo-recipes/models/llama3/collator.py | 13 + .../models/llama3/modeling_llama_te.py | 65 +++- .../llama3/tests/test_distributed_fp8.py | 243 +++++++++++++ .../llama3/tests/test_layer_quantization.py | 180 ++++++++++ bionemo-recipes/models/mixtral/collator.py | 13 + bionemo-recipes/models/qwen/collator.py | 13 + .../recipes/esm2_native_te/collator.py | 13 + .../recipes/esm2_peft_te/collator.py | 13 + .../recipes/llama3_native_te/Dockerfile | 11 +- .../recipes/llama3_native_te/README.md | 94 +++-- .../recipes/llama3_native_te/checkpoint.py | 76 ++++ .../recipes/llama3_native_te/collator.py | 13 + .../recipes/llama3_native_te/dataset.py | 6 +- .../llama3_native_te/fp4_debugging_stats.yaml | 33 ++ .../recipes/llama3_native_te/fp8_debugging.py | 64 ---- .../llama3_native_te/fp8_debugging_stats.yaml | 7 +- .../hydra_config/L2_lingua_70b.yaml | 73 ++++ .../hydra_config/L2_lingua_70b_mxfp8.yaml | 23 ++ .../hydra_config/L2_lingua_70b_mxfp8_cp4.yaml | 20 ++ .../L2_lingua_70b_mxfp8_qinit.yaml | 23 ++ .../L2_lingua_70b_mxfp8_qinit_thd.yaml | 29 ++ .../hydra_config/L2_lingua_70b_mxfp8_thd.yaml | 18 + .../L2_lingua_70b_mxfp8_thd_cp4.yaml | 20 ++ .../hydra_config/L2_lingua_70b_thd.yaml | 16 + .../hydra_config/L2_lingua_7b.yaml | 63 ++++ .../L2_lingua_7b_bf16_baseline.yaml | 12 + .../hydra_config/L2_lingua_7b_fp8.yaml | 21 ++ .../hydra_config/L2_lingua_7b_mxfp8.yaml | 22 ++ .../L2_lingua_7b_mxfp8_fl1_qinit.yaml | 28 ++ .../L2_lingua_7b_mxfp8_qinit.yaml | 22 ++ .../hydra_config/L2_lingua_7b_pure_bf16.yaml | 13 + .../hydra_config/defaults.yaml | 19 +- .../meta-llama/Llama-3.1-70B/config.json | 35 ++ .../meta-llama/Llama-3.1-8B/config.json | 35 ++ .../llama3_native_te/modeling_llama_te.py | 65 +++- .../recipes/llama3_native_te/perf_logger.py | 10 +- .../recipes/llama3_native_te/quantization.py | 223 ++++++++++++ .../recipes/llama3_native_te/requirements.txt | 2 - .../llama3_native_te/tests/conftest.py | 2 + .../test_mxfp8_fsdp2_checkpoint_resume.py | 311 ++++++++++++++++ .../tests/test_perf_logger.py | 2 +- .../tests/test_quantization.py | 332 ++++++++++++++++++ .../tests/test_quantized_model_init.py | 163 +++++++++ .../llama3_native_te/tests/test_train.py | 120 ++++++- .../recipes/llama3_native_te/train_ddp.py | 79 ++++- .../recipes/llama3_native_te/train_fsdp2.py | 250 +++++++++++-- .../llama3_native_te/train_fsdp2_cp.py | 128 ++++++- .../opengenome2_llama_native_te/collator.py | 13 + .../modeling_llama_te.py | 65 +++- 50 files changed, 2928 insertions(+), 199 deletions(-) create mode 100644 bionemo-recipes/models/llama3/tests/test_distributed_fp8.py create mode 100644 bionemo-recipes/models/llama3/tests/test_layer_quantization.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_cp4.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit_thd.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd_cp4.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_thd.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_bf16_baseline.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_fp8.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_fl1_qinit.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_qinit.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-70B/config.json create mode 100644 bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json create mode 100644 bionemo-recipes/recipes/llama3_native_te/quantization.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/tests/test_mxfp8_fsdp2_checkpoint_resume.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py create mode 100644 bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py diff --git a/bionemo-recipes/models/esm2/collator.py b/bionemo-recipes/models/esm2/collator.py index e83d719eb7..add487c997 100644 --- a/bionemo-recipes/models/esm2/collator.py +++ b/bionemo-recipes/models/esm2/collator.py @@ -335,6 +335,19 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) + def state_dict(self) -> dict: + """Delegate to the underlying HF IterableDataset's state tracking. + + This enables StatefulDataLoader to save/restore the stream position instead of + falling back to naive fast-forward (which crashes on cross-process restarts due to + last_yielded_worker_id mismatch with non-deterministic streaming datasets). + """ + return self.dataset.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Restore the underlying HF IterableDataset's stream position.""" + self.dataset.load_state_dict(state_dict) + @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py index 4555c1762a..058b625bec 100644 --- a/bionemo-recipes/models/llama3/collator.py +++ b/bionemo-recipes/models/llama3/collator.py @@ -341,6 +341,19 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) + def state_dict(self) -> dict: + """Delegate to the underlying HF IterableDataset's state tracking. + + This enables StatefulDataLoader to save/restore the stream position instead of + falling back to naive fast-forward (which crashes on cross-process restarts due to + last_yielded_worker_id mismatch with non-deterministic streaming datasets). + """ + return self.dataset.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Restore the underlying HF IterableDataset's stream position.""" + self.dataset.load_state_dict(state_dict) + @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py index 0b2236b5bf..1d912f1ba1 100644 --- a/bionemo-recipes/models/llama3/modeling_llama_te.py +++ b/bionemo-recipes/models/llama3/modeling_llama_te.py @@ -52,6 +52,7 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + layer_precision: list[str | None] | None = None def __init__( self, @@ -217,11 +218,54 @@ def _init_method(x): self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq + self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None + self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + def set_recipes( + self, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + ) -> None: + """Attach quantization recipe objects for per-layer autocast. + + Recipes are not serializable and must be set at runtime after model creation + and sharding (FSDP/DDP) but before training. The per-layer precision + assignments are read from ``self.config.layer_precision``. + + Args: + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. + """ + precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, input_ids: torch.Tensor | None = None, @@ -298,12 +342,14 @@ def forward( if te_rope_emb.dtype != torch.float32: warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning) - with self.get_autocast_context(None, outer=True): - for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled + # by get_layer_autocast(), which nests inside this context. + with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe): + for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): if output_hidden_states: all_hidden_states = (*all_hidden_states, hidden_states) - with self.get_autocast_context(layer_idx): + with self.get_layer_autocast(layer_number): hidden_states = decoder_layer( hidden_states, attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, @@ -363,8 +409,12 @@ def get_autocast_context( if init and self.config.use_quantized_model_init: if precision in ("fp8", "fp4"): - return transformer_engine.pytorch.quantized_model_init(recipe=recipe) - return nullcontext() + # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext() + # preserves the outer context's settings (recipe, preserve_high_precision_init_val). + # A nested quantized_model_init would override preserve_high_precision_init_val to False. + return nullcontext() + # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context. + return transformer_engine.pytorch.quantized_model_init(enabled=False) if precision == "fp8": if recipe is None: @@ -591,8 +641,3 @@ def reorder_cache(self, beam_idx: torch.LongTensor): updated_key_cache = key_cache.index_select(0, beam_idx) updated_value_cache = value_cache.index_select(0, beam_idx) self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache) - - @property - def is_compileable(self) -> bool: - """Return False as this cache is not compatible with torch.compile.""" - return False diff --git a/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py b/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py new file mode 100644 index 0000000000..eb93415d50 --- /dev/null +++ b/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import subprocess + +import pytest +import torch +from transformer_engine.pytorch.fp8 import check_fp8_support + + +def requires_fp8(func): + """Decorator to skip tests that require FP8 support.""" + fp8_available, reason = check_fp8_support() + return pytest.mark.skipif(not fp8_available, reason=f"FP8 is not supported on this GPU: {reason}")(func) + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + +@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"]) +@requires_fp8 +def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port): + cmd = [ + "torchrun", + "--nproc_per_node=1", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + os.path.relpath(__file__), + "--strategy", + strategy, + ] + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command failed with exit code {result.returncode}") + + +@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"]) +@requires_fp8 +@requires_multi_gpu +def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port): + cmd = [ + "torchrun", + "--nproc_per_node=2", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + os.path.relpath(__file__), + "--strategy", + strategy, + ] + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command failed with exit code {result.returncode}") + + +if __name__ == "__main__": + import argparse + import enum + import os + import sys + from dataclasses import dataclass, field + from pathlib import Path + + # Ensure the model directory is on sys.path for bare module imports. + sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix()) + + import torch.distributed as dist + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.fsdp import fully_shard + from torch.optim import AdamW + from transformer_engine.pytorch.fp8 import DelayedScaling, Format + + from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM + + def recursive_assert(a, b, path=""): + if isinstance(a, dict) and isinstance(b, dict): + assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}" + for k in a: + recursive_assert(a[k], b[k], path=f"{path}.{k}") + elif isinstance(a, list) and isinstance(b, list): + assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}" + for i in range(len(a)): + recursive_assert(a[i], b[i], path=f"{path}.{i}") + elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}") + else: + assert a == b, f"Value mismatch at {path}: {a} != {b}" + + class Strategy(enum.StrEnum): + DDP = "ddp" + FSDP2 = "fsdp2" + + @dataclass + class DistributedConfig: + """Class to track distributed ranks.""" + + rank: int = field(default_factory=dist.get_rank) + local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"])) + world_size: int = field(default_factory=dist.get_world_size) + + def is_main_process(self) -> bool: + """This is the global rank 0 process, to be used for wandb logging, etc.""" + return self.rank == 0 + + parser = argparse.ArgumentParser() + parser.add_argument("--strategy", type=Strategy, default=Strategy.DDP, choices=[Strategy.FSDP2, Strategy.DDP]) + args = parser.parse_args() + + torch.distributed.init_process_group(backend="nccl") + dist_config = DistributedConfig() + torch.cuda.set_device(dist_config.local_rank) + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(dist_config.world_size, 1), + mesh_dim_names=("dp", "tp"), + ) + device = f"cuda:{dist_config.local_rank}" + + fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10) + + config = NVLlamaConfig( + hidden_size=256, + intermediate_size=512, + num_hidden_layers=6, + num_attention_heads=8, + num_key_value_heads=4, + vocab_size=100, + dtype=torch.bfloat16, + ) + config.layer_precision = ["fp8"] * config.num_hidden_layers + model = NVLlamaForCausalLM(config) + + if args.strategy is Strategy.FSDP2: + for layer in model.model.layers: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) + model.to(device) + + elif args.strategy is Strategy.DDP: + model.to(device) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + device_mesh=device_mesh["dp"], + ) + + optimizer = AdamW(model.parameters()) + + # Attach FP8 recipes to the model (layer precision is already on config). + llama_model = model.module.model if args.strategy is Strategy.DDP else model.model + llama_model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None) + + model.train() + + generator = torch.Generator() + generator.manual_seed(torch.distributed.get_rank()) + + for _ in range(3): + input_data = { + "input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator), + "labels": torch.randint(0, config.vocab_size, (1, 32), generator=generator), + "attention_mask": torch.ones(1, 32), + } + input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()} + + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + outputs = model(**input_data) + + outputs.loss.backward() + + # Access FP8 extra states directly from modules instead of state_dict() + # since state_dict() now filters them out for HuggingFace compatibility + fp8_extra_states = {} + for name, module in model.named_modules(): + if hasattr(module, "_extra_state") and callable(module._extra_state): + extra_state = module._extra_state() + if extra_state is not None and len(extra_state) > 0: + fp8_extra_states[f"{name}._extra_state"] = extra_state + + # lm_head is BF16, not FP8, so exclude it from FP8 checks + fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key} + + # 2 ranks, test to ensure that both ranks have the same FP8 extra states + if torch.distributed.get_world_size() == 2: + outputs_list = [None] * torch.distributed.get_world_size() if torch.distributed.get_rank() == 0 else None + torch.distributed.gather_object(fp8_extra_states, outputs_list, dst=0) + if torch.distributed.get_rank() == 0: + assert outputs_list is not None + + for key in outputs_list[0]: + state_1 = outputs_list[0][key] + state_2 = outputs_list[1][key] + assert len(state_1) > 0, f"No FP8 extra states for {key}, rank 0" + assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1" + dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes()) + dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes()) + recursive_assert(dict_1, dict_2) + + # One rank, test to ensure the correct FP8 extra states are saved + if torch.distributed.get_world_size() == 1: + for key, val in fp8_extra_states.items(): + assert len(val) > 0, f"No FP8 extra states for {key}" + fp8_meta_dict = pickle.loads(val.detach().numpy(force=True).tobytes()) + assert fp8_meta_dict["recipe"] == fp8_recipe, f"Recipe mismatch for {key}" + + torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/models/llama3/tests/test_layer_quantization.py b/bionemo-recipes/models/llama3/tests/test_layer_quantization.py new file mode 100644 index 0000000000..a80ff80f2c --- /dev/null +++ b/bionemo-recipes/models/llama3/tests/test_layer_quantization.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for NVLlamaModel.set_recipes and get_layer_autocast.""" + +from contextlib import nullcontext +from unittest.mock import patch + +import pytest +import transformer_engine.common.recipe +import transformer_engine.pytorch + +from modeling_llama_te import NVLlamaConfig, NVLlamaModel + + +@pytest.fixture +def model(): + """Create a small NVLlamaModel for testing.""" + config = NVLlamaConfig( + hidden_size=256, + intermediate_size=512, + num_hidden_layers=6, + num_attention_heads=8, + num_key_value_heads=4, + vocab_size=100, + ) + return NVLlamaModel(config) + + +# -- set_recipes -- + + +def test_all_fp8(model): + model.config.layer_precision = ["fp8"] * 6 + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None) + assert model._fp8_recipe is fp8_recipe + assert model._fp4_recipe is None + assert all(p == "fp8" for p in model.config.layer_precision) + + +def test_all_fp4(model): + model.config.layer_precision = ["fp4"] * 6 + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe) + assert model._fp8_recipe is None + assert model._fp4_recipe is fp4_recipe + assert all(p == "fp4" for p in model.config.layer_precision) + + +def test_all_bf16(model): + model.config.layer_precision = [None] * 6 + model.set_recipes(fp8_recipe=None, fp4_recipe=None) + assert all(p is None for p in model.config.layer_precision) + + +def test_mixed_fp8_fp4(model): + model.config.layer_precision = ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"] + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + assert model.config.layer_precision == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"] + + +def test_mixed_fp8_bf16(model): + model.config.layer_precision = ["fp8", None, "fp8", None, "fp8", None] + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None) + assert model.config.layer_precision == ["fp8", None, "fp8", None, "fp8", None] + + +def test_mixed_all_three(model): + model.config.layer_precision = ["fp8", "fp8", None, None, "fp4", "fp4"] + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + assert model.config.layer_precision == ["fp8", "fp8", None, None, "fp4", "fp4"] + + +def test_covers_all_layers(model): + model.config.layer_precision = ["fp8"] + [None] * 5 + model.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None) + assert len(model.config.layer_precision) == 6 + + +def test_recipes_stored_as_attributes(model): + model.config.layer_precision = ["fp8", "fp4", None, None, None, None] + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + assert model._fp8_recipe is fp8_recipe + assert model._fp4_recipe is fp4_recipe + # The precision list only contains strings/None, not recipe objects. + for v in model.config.layer_precision: + assert v is None or isinstance(v, str) + + +# -- get_layer_autocast -- + + +def test_fp8_layer_returns_nullcontext(model): + model.config.layer_precision = ["fp8"] + [None] * 5 + model.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None) + ctx = model.get_layer_autocast(0) + assert isinstance(ctx, nullcontext) + + +def test_fp4_layer_returns_te_autocast(model): + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.config.layer_precision = ["fp4"] + [None] * 5 + model.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "fp4_context" + ctx = model.get_layer_autocast(0) + mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe) + assert ctx == "fp4_context" + + +def test_bf16_layer_returns_te_autocast_disabled(model): + model.config.layer_precision = [None] * 6 + model.set_recipes(fp8_recipe=None, fp4_recipe=None) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + ctx = model.get_layer_autocast(0) + mock_autocast.assert_called_once_with(enabled=False) + assert ctx == "bf16_context" + + +def test_uninitialized_defaults_to_bf16(model): + """When layer_precision is None (default), all layers default to BF16.""" + assert model.config.layer_precision is None + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + ctx = model.get_layer_autocast(0) + mock_autocast.assert_called_once_with(enabled=False) + assert ctx == "bf16_context" + + +def test_mixed_layers_return_correct_contexts(model): + fp8_recipe = transformer_engine.common.recipe.DelayedScaling() + fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling() + model.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None] + model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + + # FP8 layers -> nullcontext + assert isinstance(model.get_layer_autocast(0), nullcontext) + assert isinstance(model.get_layer_autocast(1), nullcontext) + + # FP4 layers -> te.pytorch.autocast + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "fp4_context" + model.get_layer_autocast(2) + mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe) + + # BF16 layers -> te.pytorch.autocast(enabled=False) + with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast: + mock_autocast.return_value = "bf16_context" + model.get_layer_autocast(4) + mock_autocast.assert_called_with(enabled=False) + + +def test_layer_precision_is_pickleable(model): + """The config.layer_precision list should be trivially pickleable.""" + import pickle + + model.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None] + roundtripped = pickle.loads(pickle.dumps(model.config.layer_precision)) + assert roundtripped == model.config.layer_precision diff --git a/bionemo-recipes/models/mixtral/collator.py b/bionemo-recipes/models/mixtral/collator.py index 4555c1762a..058b625bec 100644 --- a/bionemo-recipes/models/mixtral/collator.py +++ b/bionemo-recipes/models/mixtral/collator.py @@ -341,6 +341,19 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) + def state_dict(self) -> dict: + """Delegate to the underlying HF IterableDataset's state tracking. + + This enables StatefulDataLoader to save/restore the stream position instead of + falling back to naive fast-forward (which crashes on cross-process restarts due to + last_yielded_worker_id mismatch with non-deterministic streaming datasets). + """ + return self.dataset.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Restore the underlying HF IterableDataset's stream position.""" + self.dataset.load_state_dict(state_dict) + @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/models/qwen/collator.py b/bionemo-recipes/models/qwen/collator.py index 4555c1762a..058b625bec 100644 --- a/bionemo-recipes/models/qwen/collator.py +++ b/bionemo-recipes/models/qwen/collator.py @@ -341,6 +341,19 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) + def state_dict(self) -> dict: + """Delegate to the underlying HF IterableDataset's state tracking. + + This enables StatefulDataLoader to save/restore the stream position instead of + falling back to naive fast-forward (which crashes on cross-process restarts due to + last_yielded_worker_id mismatch with non-deterministic streaming datasets). + """ + return self.dataset.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Restore the underlying HF IterableDataset's stream position.""" + self.dataset.load_state_dict(state_dict) + @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py index 4555c1762a..058b625bec 100644 --- a/bionemo-recipes/recipes/esm2_native_te/collator.py +++ b/bionemo-recipes/recipes/esm2_native_te/collator.py @@ -341,6 +341,19 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) + def state_dict(self) -> dict: + """Delegate to the underlying HF IterableDataset's state tracking. + + This enables StatefulDataLoader to save/restore the stream position instead of + falling back to naive fast-forward (which crashes on cross-process restarts due to + last_yielded_worker_id mismatch with non-deterministic streaming datasets). + """ + return self.dataset.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Restore the underlying HF IterableDataset's stream position.""" + self.dataset.load_state_dict(state_dict) + @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/esm2_peft_te/collator.py b/bionemo-recipes/recipes/esm2_peft_te/collator.py index 4555c1762a..058b625bec 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/collator.py +++ b/bionemo-recipes/recipes/esm2_peft_te/collator.py @@ -341,6 +341,19 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) + def state_dict(self) -> dict: + """Delegate to the underlying HF IterableDataset's state tracking. + + This enables StatefulDataLoader to save/restore the stream position instead of + falling back to naive fast-forward (which crashes on cross-process restarts due to + last_yielded_worker_id mismatch with non-deterministic streaming datasets). + """ + return self.dataset.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Restore the underlying HF IterableDataset's stream position.""" + self.dataset.load_state_dict(state_dict) + @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/llama3_native_te/Dockerfile b/bionemo-recipes/recipes/llama3_native_te/Dockerfile index b72c36b890..c3dd031de2 100644 --- a/bionemo-recipes/recipes/llama3_native_te/Dockerfile +++ b/bionemo-recipes/recipes/llama3_native_te/Dockerfile @@ -1,5 +1,14 @@ # syntax=docker/dockerfile:1.4 -FROM nvcr.io/nvidia/pytorch:26.04-py3 +FROM nvcr.io/nvidia/pytorch:26.03-py3 + +# Rebuild TransformerEngine from main branch (includes PR #2753: MXFP8 FusedAdam support). +# To pin a specific commit, replace 'main' with a commit hash. +# Build: docker build -t llama3_native_te:te-main-26.03 . +RUN pip uninstall -y transformer_engine transformer_engine_torch transformer_engine_cu12 && \ + git clone --recursive https://github.com/NVIDIA/TransformerEngine.git /opt/te && \ + cd /opt/te && git checkout main && \ + NVTE_FRAMEWORK=pytorch MAX_JOBS=8 pip install --no-build-isolation . && \ + rm -rf /opt/te RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md index 2be3b0f11e..67ca194452 100644 --- a/bionemo-recipes/recipes/llama3_native_te/README.md +++ b/bionemo-recipes/recipes/llama3_native_te/README.md @@ -1,8 +1,8 @@ # TransformerEngine-accelerated Llama 3 training with native PyTorch training loop This folder demonstrates how to train TE-accelerated Llama 3 with a native PyTorch training loop, including sequence -packing and FP8 precision, using fully sharded data parallel (FSDP) for distributed training. This recipe is configured -for genomic sequences using a custom nucleotide tokenizer. +packing, FP8/MXFP8/NVFP4 precision with layer-wise control, using fully sharded data parallel (FSDP) for distributed +training. This recipe is configured for genomic sequences using a custom nucleotide tokenizer. ## How to use this recipe @@ -16,9 +16,9 @@ bionemo-framework repository. You can download a zipped directory of this folder ## Supported Models and Training Features -| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism | -| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ | -| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | +| Model | BF16 | FP8[1] | MXFP8[2] | NVFP4[3] | THD Input Format | Context Parallelism | Tensor Parallelism | +| ---------------------------------------- | ---- | ----------------- | ------------------- | ------------------- | ---------------- | ------------------- | ------------------ | +| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅: Supported
🚧: Under development
@@ -26,6 +26,7 @@ bionemo-framework repository. You can download a zipped directory of this folder \[1\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 9.0 and above (Hopper+)
\[2\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and 10.3 (Blackwell), 12.0 support pending
+\[3\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and above (Blackwell+)
### Installing Dependencies @@ -64,19 +65,24 @@ def compute_model_pflops(seq_len, global_batch_size, step_time_s): return model_flops / 1e15 ``` +### Low precision performance benchmarks + +![Performance Benchmarks Low Precision](../../../docs/docs/assets/images/llama3/llama3_8gpu_tflops.png) +In the above plot we can see the performance increases as we lower the precision of our transformer layers across the 1B and 8B variant of LLAMA3. + ### Convergence Benchmarks

- Llama 3 Lingua 1B Loss Curve - Llama 3 Lingua 1B Step Time + Llama 3 Lingua 1B Loss Curve + Llama 3 Lingua 1B Step Time

-We compared the convergence of this Llama3 recipe (with FSDP2) against -[NeMo 2.0](https://github.com/NVIDIA-NeMo/NeMo) and the [facebookresearch/lingua](https://github.com/facebookresearch/lingua) +We compared the convergence of this Llama3 recipe (with FSDP2) against NeMo 2.0 +(https://github.com/NVIDIA-NeMo/NeMo) and the [facebookresearch/lingua](https://github.com/facebookresearch/lingua) implementation on the DCLM Baseline 1.0 dataset. See [Training on Natural Language Data (Lingua -Reproduction)](#training-on-natural-language-data-lingua-reproduction) for more details. The figure above shows similar loss convergence and step time to +Reproduction)](#lingua-reproduction) for more details. The figure above shows similar loss convergence and step time to the NeMo 2.0 training example, and the following table shows downstream performance on various tasks using the -[lm-eval](https://github.com/eleutherai/lm-evaluation-harness) library. The variation in training step time every 10,000 steps +[lm-eval](github.com/eleutherai/lm-evaluation-harness) library. The variation in training step time every 10,000 steps are due checkpointing, further work will be done to improve training step time stability. | name | arc_challenge | arc_easy | boolq | copa | hella_swag | piqa | winogrande | @@ -88,6 +94,10 @@ are due checkpointing, further work will be done to improve training step time s Models were trained on 64 NVIDIA H100 GPUs with a micro batch size of 4 and a context length of 4096 for 60,000 steps. Training was performed with BF16 precision. +### Low Precision convergence benchmarks + + + ### Distributed Training This recipe supports distributed training using DDP, FSDP2, and FSDP2 with Context Parallelism, shown in three separate training entrypoints: @@ -127,10 +137,10 @@ batch size while running on a smaller number of GPUs. python train_fsdp2.py --config-name L0_sanity grad_acc_steps=2 ``` -### FP8 Training +### Quantized Training (FP8 / MXFP8 / NVFP4) To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8 -configuration parameters, including switching to `MXFP8BlockScaling`, can be set via the hydra configuration. +configuration parameters, including switching to `MXFP8BlockScaling`, can be set using the hydra configuration. ```bash python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true @@ -150,24 +160,60 @@ python train_fsdp2.py --config-name L0_sanity \ #### FP8 Debugging -We also provide a mechanism to receive tensor data related to FP8 layers during training which may include activations, weights and gradients. +```bash +python train_fsdp2.py --config-name L0_sanity fp4_config.enabled=true +``` + +Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. NVFP4 stats logging is not yet +supported and will be enabled in a future TransformerEngine release; FP8/MXFP8 stats logging works today. + +Additional recipe parameters (e.g., switching to `MXFP8BlockScaling`) can be set via the hydra configuration. + +#### Layer-Wise Precision + +You can control which transformer layers use FP8 or FP4 by specifying 1-indexed layer numbers via `fp8_layers` and +`fp4_layers`. Layers not assigned to either format will run in BF16. + +For example, to run layers 1-3 in FP8, layers 4-6 in FP4, and the rest in BF16 on a model with more than 6 layers: + +```bash +python train_fsdp2.py --config-name L0_sanity \ + fp8_config.enabled=true \ + fp4_config.enabled=true \ + 'fp8_layers=[1,2,3]' \ + 'fp4_layers=[4,5,6]' +``` + +When both `fp8_config` and `fp4_config` are enabled but only one layer list is provided, the other format automatically +claims the remaining layers. For example, if `fp8_layers=[1,2,3]` is set and `fp4_config.enabled=true` with no +`fp4_layers`, then layers 4 through N will default to FP4. + +#### Quantization Stats Debugging + +We provide a mechanism to log tensor statistics (activations, weights, gradients) for quantized layers during training. +When layer-wise precision is used, the stats config is automatically updated so that only the relevant layers are +tracked. -To enable this please select the following config options. +To enable stats logging: ```bash python train_fsdp2.py \ - fp8_stats_config.enabled=True \ - fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \ - fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml \ - fp8_config.enabled=True + quant_stats_config.enabled=true \ + quant_stats_config.quant_log_dir=./logs/quant_stats \ + quant_stats_config.quant_stats_file=./fp8_debugging_stats.yaml \ + fp8_config.enabled=true ``` -Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. +Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. NVFP4 stats logging is not yet +supported and will be enabled in a future TransformerEngine release; FP8/MXFP8 stats logging works today. -The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure. +The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the +[NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) +in more detail. -This comes as a performance cost that is dependent on the `freq` parameter mentioned above. `freq=1` collects stats on every step which in our -experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We recommend using `freq>=10` to reduce this performance hit. +Stats collection has a performance cost dependent on the `freq` parameter in the config file. `freq=1` collects stats +on every step which in our experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We +recommend using `freq>=10` to reduce this performance hit. ### Sequence Packing (THD input format) @@ -217,7 +263,7 @@ python train_fsdp2.py --config-name L0_sanity \ dataset.load_dataset_kwargs.path=/path/to/download/directory ``` -## Training on Natural Language Data (Lingua Reproduction) +## Training on Natural Language Data (Lingua Reproduction) {#lingua-reproduction} We provide a configuration to reproduce the Llama-3.2-1B training experiments from [Meta Lingua](https://github.com/facebookresearch/lingua), using the [DCLM Baseline diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py index 2dc5d10dcf..d0accd200d 100644 --- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py +++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py @@ -34,6 +34,7 @@ from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save from torch.distributed.checkpoint.state_dict_saver import save as dcp_save from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam as _FSDPParam from torch.distributed.tensor import DTensor from torchdata.stateful_dataloader import StatefulDataLoader from transformer_engine.pytorch.quantized_tensor import QuantizedTensor @@ -41,6 +42,81 @@ from distributed_config import DistributedConfig +# --------------------------------------------------------------------------- +# Monkey-patch FSDP2's FSDPParam.reset_sharded_param to handle QuantizedTensor. +# +# After checkpoint load, set_state_dict calls copy_() on FSDP-sharded params. +# For QuantizedTensor (MXFP8Tensor), copy_() re-quantizes which can invalidate +# the old untyped_storage, causing data_ptr() to crash. The original code +# (PyTorch _fsdp_param.py) compares storage pointers without guarding against +# QuantizedTensor. This patch wraps the comparison in a try/except so that +# reset_sharded_param can proceed normally (re-recording _sharded_param_data). +# --------------------------------------------------------------------------- + + +def _patched_reset_sharded_param(self): # type: ignore[no-untyped-def] + """reset_sharded_param with QuantizedTensor safety.""" + module_info = self._module_info + new_param = getattr(module_info.module, module_info.param_name) + if new_param is not self.sharded_param: + if torch.__future__.get_swap_module_params_on_conversion(): + raise AssertionError( + f"Expects swap_tensors to preserve object but got {new_param} instead of {self.sharded_param}" + ) + self.sharded_param = new_param + + local_tensor = new_param._local_tensor + if local_tensor.is_meta: + return + + updated_local_tensor = False + same_local_tensor = False + + if type(self._sharded_param_data) is torch.Tensor: + try: + same_local_tensor = ( + self._sharded_param_data.untyped_storage().data_ptr() > 0 + and self._sharded_param_data.untyped_storage().data_ptr() == local_tensor.untyped_storage().data_ptr() + ) + except RuntimeError: + # QuantizedTensor (e.g. MXFP8Tensor) can have invalid storage + # after copy_() re-quantization. Treat as not-same so that + # _sharded_param_data gets re-recorded below. + same_local_tensor = False + + padded_sharded_size = self.padded_sharded_param_size + shard_dim = self.fsdp_placement.dim + length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0 + + if local_tensor.size() != padded_sharded_size and not same_local_tensor: + if shard_dim != 0: + raise AssertionError(f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}") + padded_local_tensor = local_tensor.new_zeros(padded_sharded_size) + padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(local_tensor) + local_tensor = padded_local_tensor + updated_local_tensor = True + + if self.pin_memory and not local_tensor.is_pinned(): + local_tensor = local_tensor.cpu().pin_memory() + updated_local_tensor = True + + if not same_local_tensor: + self._sharded_param_data = local_tensor.view(-1) + + if not isinstance(self.sharded_param, DTensor): + raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}") + + if updated_local_tensor: + self.sharded_param._local_tensor = local_tensor.narrow(dim=shard_dim, start=0, length=length) + if not self.sharded_param._local_tensor.is_contiguous(): + raise AssertionError("Expected sharded_param._local_tensor to be contiguous") + + self._sharding_spec = self.sharded_param._spec + + +_FSDPParam.reset_sharded_param = _patched_reset_sharded_param + + logger = logging.getLogger(__name__) # Tracks in-flight async checkpoint futures keyed by strategy name (e.g. "fsdp2"). diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py index 4555c1762a..058b625bec 100644 --- a/bionemo-recipes/recipes/llama3_native_te/collator.py +++ b/bionemo-recipes/recipes/llama3_native_te/collator.py @@ -341,6 +341,19 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) + def state_dict(self) -> dict: + """Delegate to the underlying HF IterableDataset's state tracking. + + This enables StatefulDataLoader to save/restore the stream position instead of + falling back to naive fast-forward (which crashes on cross-process restarts due to + last_yielded_worker_id mismatch with non-deterministic streaming datasets). + """ + return self.dataset.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Restore the underlying HF IterableDataset's stream position.""" + self.dataset.load_state_dict(state_dict) + @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/llama3_native_te/dataset.py b/bionemo-recipes/recipes/llama3_native_te/dataset.py index adca62a43a..797e47c729 100644 --- a/bionemo-recipes/recipes/llama3_native_te/dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/dataset.py @@ -192,7 +192,6 @@ def create_bshd_dataloader( data_collator = base_collator logger.info("Using standard DataCollatorForLanguageModeling") - # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again. dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader train_dataloader = dataloader_class( tokenized_dataset, @@ -200,7 +199,7 @@ def create_bshd_dataloader( batch_size=micro_batch_size, collate_fn=data_collator, num_workers=num_workers, - pin_memory=not use_stateful_dataloader, + pin_memory=True, persistent_workers=num_workers > 0, prefetch_factor=prefetch_factor if num_workers > 0 else None, ) @@ -288,7 +287,6 @@ def create_thd_dataloader( f"Using GenomicDataCollator (uppercase={uppercase_labels}, mask_degenerate={mask_degenerate_bases})" ) - # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again. dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader train_dataloader = dataloader_class( TokenPackingDataset( @@ -299,7 +297,7 @@ def create_thd_dataloader( batch_size=None, # The TokenPackingDataset will handle the batching. collate_fn=data_collator, num_workers=num_workers, - pin_memory=not use_stateful_dataloader, + pin_memory=True, persistent_workers=num_workers > 0, prefetch_factor=prefetch_factor if num_workers > 0 else None, ) diff --git a/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml new file mode 100644 index 0000000000..9046d44caf --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml @@ -0,0 +1,33 @@ +example_fp4_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers (1-indexed as layers.1 through layers.N in the naming) + # This matches: model.model.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.model\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 100 + - tensor: gradient + stats: [underflows%, mse] + freq: 100 + +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers (1-indexed as layers.1 through layers.N in the naming) + # This matches: model.model.layers.[6-10].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.model\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogFp8TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 + - tensor: gradient + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py deleted file mode 100644 index d01024f04c..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -from pathlib import Path - -import nvdlfw_inspect.api as debug_api -import transformer_engine - -from distributed_config import DistributedConfig - - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -def initialize_fp8_debugging( - dist_config: DistributedConfig, - enabled: bool, - fp8_stats_file: str, - fp8_log_dir: str | os.PathLike, - fp8_enabled: bool, -) -> None: - """Initialize FP8 debugging. - - Args: - dist_config: The distributed configuration. - enabled: Whether to enable FP8 debugging. - fp8_stats_file: The file containing the FP8 stats. - fp8_log_dir: The directory to log the FP8 stats to. - fp8_enabled: Whether FP8 autocast is enabled. - """ - if not enabled: - return - - if not fp8_enabled: - raise ValueError( - "fp8_stats_config.enabled is true but fp8_config.enabled is false, " - "please set fp8_config.enabled to true in the config if you wish to collect FP8 stats" - ) - - fp8_log_dir = Path(fp8_log_dir) / f"rank_{dist_config.rank}" - fp8_log_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Logging FP8 stats to {fp8_log_dir}") - te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") - debug_api.initialize( - config_file=fp8_stats_file, - feature_dirs=[te_features_dir], - log_dir=fp8_log_dir.as_posix(), - default_logging_enabled=True, - ) diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml index 7544bbedcf..ba640a6cbb 100644 --- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml @@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection: enabled: True layers: # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv] + layer_types: [layernorm_qkv, proj, fc1, fc2] transformer_engine: LogFp8TensorStats: enabled: True @@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection: - tensor: weight stats: [underflows%, scale_inv_min, scale_inv_max, mse] freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [dgrad, wgrad] + freq: 1 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml new file mode 100644 index 0000000000..15069fde62 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml @@ -0,0 +1,73 @@ +# Lingua 70B BF16 with Context Parallelism (CP=2). +# Uses train_fsdp2_cp.py, not train_fsdp2.py. + +defaults: + - defaults + - _self_ + +config_name_or_path: ./model_configs/meta-llama/Llama-3.1-70B + +config_kwargs: + attn_input_format: bshd + self_attn_mask_type: causal + +# CP=2 halves per-GPU sequence length, cutting attention activation memory ~4x. +cp_size: 2 + +# BSHD format required for CP (no sequence packing). +use_sequence_packing: false + +# Meta device init is critical for 70B to avoid OOM during model construction. +use_meta_device: true + +# FP32 master weights via TE FusedAdam +use_fp32_master_weights_fused: true + +wandb: + name: lingua-70b-bf16-cp2 + project: lingua-70b + id: lingua-70b-bf16-cp2 + +num_train_steps: 60_000 + +dataset: + tokenizer_name_or_path: ./tokenizers/Meta-Llama-3-8B + micro_batch_size: 1 + num_workers: 4 + max_seq_length: 8192 + stride: 512 + buffer_size: 10_000 + use_stateful_dataloader: false + mask_degenerate_bases: false + uppercase_labels: false + load_dataset_kwargs: + path: parquet + data_files: "/workspace/data/dclm-baseline/global-shard_01_of_10/**/*.parquet" + split: "train" + streaming: true + +# With CP=2, dp_size = 32/2 = 16. GBS = 1 * 16 * 16 = 256 +grad_acc_steps: 16 + +adamw_kwargs: + lr: 3e-4 + fused: true + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 0.01 + +lr_scheduler_kwargs: + num_warmup_steps: 5_000 + num_decay_steps: 55_000 + min_lr_ratio: 0.000001 + +# Checkpoint config +checkpoint: + ckpt_dir: null + save_final_model: true + resume_from_checkpoint: true + save_every_n_steps: 1000 + async_save: true + +profiler: + enabled: false diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8.yaml new file mode 100644 index 0000000000..d28022dad0 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8.yaml @@ -0,0 +1,23 @@ +# Lingua 70B MXFP8 with Context Parallelism (CP=2). +# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16. +# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. + +defaults: + - L2_lingua_70b + - _self_ + +# FP8 with MXFP8BlockScaling (hardware accelerated on Blackwell) +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling + fp8_format: E4M3 + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: false + +# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16 +fp8_layers: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79] + +wandb: + name: lingua-70b-mxfp8-fl1-cp2 + id: lingua-70b-mxfp8-fl1-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_cp4.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_cp4.yaml new file mode 100644 index 0000000000..c34d88124b --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_cp4.yaml @@ -0,0 +1,20 @@ +# Lingua 70B MXFP8 with Context Parallelism (CP=4). +# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16. +# CP=4 reduces per-GPU activation memory, needed for B200 GPUs (192GB) +# where CP=2 OOMs with FP32 master weights. +# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. + +defaults: + - L2_lingua_70b_mxfp8 + - _self_ + +cp_size: 4 + +dataset: + # MXFP8 block size is 32. With CP=4, post-split dims must be divisible by 32. + # Pre-split: pad to 128 (= 32 * CP). After CP splits by 4, each chunk is divisible by 32. + pad_sequences_to_be_divisible_by: 128 + +wandb: + name: lingua-70b-mxfp8-fl1-cp4 + id: lingua-70b-mxfp8-fl1-cp4 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit.yaml new file mode 100644 index 0000000000..d226d32214 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit.yaml @@ -0,0 +1,23 @@ +# Lingua 70B MXFP8 with quantized model init (all layers FP8) + CP=2. +# Quantized model init stores only FP8 weights (no BF16 copies), saving memory. +# FP32 master weights are maintained in FusedAdam optimizer. +# BSHD format (no sequence packing) — required for CP. +# Requires Blackwell GPUs (GB200/B300) for hardware MXFP8 support. + +defaults: + - L2_lingua_70b + - _self_ + +# All layers in FP8 (no FL1 exclusion) — compatible with quantized_model_init. +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling + fp8_format: E4M3 + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: true + preserve_high_precision_init_val: true + +wandb: + name: lingua-70b-mxfp8-qinit-cp2 + id: lingua-70b-mxfp8-qinit-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit_thd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit_thd.yaml new file mode 100644 index 0000000000..c08093bfca --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit_thd.yaml @@ -0,0 +1,29 @@ +# Lingua 70B MXFP8 THD with quantized model init (all layers FP8) + CP=2. +# Quantized model init stores only FP8 weights (no BF16 copies), saving memory. +# FP32 master weights are maintained in FusedAdam optimizer. +# THD enables sequence packing for better GPU utilization. +# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. + +defaults: + - L2_lingua_70b + - _self_ + +config_kwargs: + attn_input_format: thd + self_attn_mask_type: padding + +use_sequence_packing: true + +# All layers in FP8 (no FL1 exclusion) — compatible with quantized_model_init. +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling + fp8_format: E4M3 + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: true + preserve_high_precision_init_val: true + +wandb: + name: lingua-70b-mxfp8-qinit-thd-cp2 + id: lingua-70b-mxfp8-qinit-thd-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd.yaml new file mode 100644 index 0000000000..fd30296024 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd.yaml @@ -0,0 +1,18 @@ +# Lingua 70B MXFP8 THD format with Context Parallelism (CP=2). +# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16. +# THD enables sequence packing for better GPU utilization. +# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. + +defaults: + - L2_lingua_70b_mxfp8 + - _self_ + +config_kwargs: + attn_input_format: thd + self_attn_mask_type: padding + +use_sequence_packing: true + +wandb: + name: lingua-70b-mxfp8-fl1-thd-cp2 + id: lingua-70b-mxfp8-fl1-thd-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd_cp4.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd_cp4.yaml new file mode 100644 index 0000000000..8547047dd3 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd_cp4.yaml @@ -0,0 +1,20 @@ +# Lingua 70B MXFP8 THD format with Context Parallelism (CP=4). +# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16. +# THD enables sequence packing for better GPU utilization. +# CP=4 reduces per-GPU activation memory, needed for B200 GPUs (192GB) +# where CP=2 OOMs with FP32 master weights. +# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. + +defaults: + - L2_lingua_70b_mxfp8_cp4 + - _self_ + +config_kwargs: + attn_input_format: thd + self_attn_mask_type: padding + +use_sequence_packing: true + +wandb: + name: lingua-70b-mxfp8-fl1-thd-cp4 + id: lingua-70b-mxfp8-fl1-thd-cp4 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_thd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_thd.yaml new file mode 100644 index 0000000000..9143dc205d --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_thd.yaml @@ -0,0 +1,16 @@ +# Lingua 70B BF16 THD format with Context Parallelism (CP=2). +# THD enables sequence packing for better GPU utilization. + +defaults: + - L2_lingua_70b + - _self_ + +config_kwargs: + attn_input_format: thd + self_attn_mask_type: padding + +use_sequence_packing: true + +wandb: + name: lingua-70b-bf16-thd-cp2 + id: lingua-70b-bf16-thd-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml new file mode 100644 index 0000000000..7ec8784a9c --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml @@ -0,0 +1,63 @@ +# Config to match the Llama-3.1-8B model pre-training experiments from https://github.com/facebookresearch/lingua. + +defaults: + - defaults + - _self_ + +config_name_or_path: ./model_configs/meta-llama/Llama-3.1-8B + +config_kwargs: + attn_input_format: thd + +use_sequence_packing: true + +# FP32 master weights via TE FusedAdam (recommended over MixedPrecisionPolicy) +use_fp32_master_weights_fused: true + +wandb: + name: lingua-7b-bf16 + project: lingua-7b + id: lingua-7b-bf16-v2 + +num_train_steps: 60_000 + +dataset: + tokenizer_name_or_path: ./tokenizers/Meta-Llama-3-8B + micro_batch_size: 2 + num_workers: 8 + max_seq_length: 8192 + stride: 512 + buffer_size: 10_000 + use_stateful_dataloader: false + mask_degenerate_bases: false + uppercase_labels: false + load_dataset_kwargs: + path: parquet + data_files: "/workspace/data/dclm-baseline/global-shard_01_of_10/**/*.parquet" + split: "train" + streaming: true + +grad_acc_steps: 4 # GBS = 2 * 4 * 32 GPUs = 256 (4 nodes) + +adamw_kwargs: + lr: 3e-4 + fused: true + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 0.01 + +lr_scheduler_kwargs: + num_warmup_steps: 5_000 + num_decay_steps: 55_000 # total_steps - num_warmup_steps = 60_000 - 5_000 + min_lr_ratio: 0.000001 + +# Checkpoint config +checkpoint: + ckpt_dir: null + save_final_model: true + resume_from_checkpoint: true + save_every_n_steps: 10_000 + async_save: false + +profiler: + enabled: false diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_bf16_baseline.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_bf16_baseline.yaml new file mode 100644 index 0000000000..51ca061a3b --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_bf16_baseline.yaml @@ -0,0 +1,12 @@ +# BF16 baseline for Lingua 7B — benchmarking step time / throughput against MXFP8. +# Same model, dataset, and hyperparams as L2_lingua_7b_mxfp8_qinit but without FP8. + +defaults: + - L2_lingua_7b + - _self_ + +num_train_steps: 1_000 + +wandb: + name: lingua-7b-bf16-baseline + id: lingua-7b-bf16-baseline diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_fp8.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_fp8.yaml new file mode 100644 index 0000000000..a787e0bff0 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_fp8.yaml @@ -0,0 +1,21 @@ +# Lingua 7B FP8 Block Scaling - FL1 (layer 1 and 32 in BF16, layers 2-31 in FP8) + +defaults: + - L2_lingua_7b + - _self_ + +# FP8 with Float8BlockScaling and E4M3 format +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling + fp8_format: E4M3 + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: false + +# FL1: layers 2-31 in FP8, layers 1 and 32 in BF16 +fp8_layers: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] + +wandb: + name: lingua-7b-fp8-bs-fl1 + id: lingua-7b-fp8-bs-fl1-v2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8.yaml new file mode 100644 index 0000000000..5f4450d65b --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8.yaml @@ -0,0 +1,22 @@ +# Lingua 7B MXFP8 - FL1 (layer 1 and 32 in BF16, layers 2-31 in FP8) +# Requires Blackwell GPUs (GB200) for hardware MXFP8 support + +defaults: + - L2_lingua_7b + - _self_ + +# FP8 with MXFP8BlockScaling (hardware accelerated on Blackwell) +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling + fp8_format: E4M3 + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: false + +# FL1: layers 2-31 in FP8, layers 1 and 32 in BF16 +fp8_layers: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] + +wandb: + name: lingua-7b-mxfp8-fl1 + id: lingua-7b-mxfp8-fl1 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_fl1_qinit.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_fl1_qinit.yaml new file mode 100644 index 0000000000..db1327411c --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_fl1_qinit.yaml @@ -0,0 +1,28 @@ +# Lingua 7B MXFP8 with quantized model init + FL1 (first/last layer BF16). +# Layers 2-31 in FP8, layers 1 and 32 in BF16. +# Quantized model init stores only FP8 weights (no BF16 copies), saving memory. +# FP32 master weights are maintained in FusedAdam optimizer. +# Requires config_kwargs.use_quantized_model_init=true so per-layer init +# correctly disables quantization for BF16 layers. +# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. + +defaults: + - L2_lingua_7b_mxfp8 + - _self_ + +config_kwargs: + attn_input_format: thd + use_quantized_model_init: true + +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling + fp8_format: E4M3 + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: true + preserve_high_precision_init_val: true + +wandb: + name: lingua-7b-mxfp8-fl1-qinit + id: lingua-7b-mxfp8-fl1-qinit diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_qinit.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_qinit.yaml new file mode 100644 index 0000000000..c6261ce9c2 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_qinit.yaml @@ -0,0 +1,22 @@ +# Lingua 7B MXFP8 with quantized model init (all layers FP8). +# Quantized model init stores only FP8 weights (no BF16 copies), saving memory. +# FP32 master weights are maintained in FusedAdam optimizer. +# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. + +defaults: + - L2_lingua_7b + - _self_ + +# All layers in FP8 (no FL1 exclusion) — compatible with outer quantized_model_init. +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling + fp8_format: E4M3 + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: true + preserve_high_precision_init_val: true + +wandb: + name: lingua-7b-mxfp8-allfp8-qinit + id: lingua-7b-mxfp8-allfp8-qinit diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml new file mode 100644 index 0000000000..6defc71b88 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml @@ -0,0 +1,13 @@ +# Lingua 7B pure BF16 - no FP32 master weights + +defaults: + - L2_lingua_7b + - _self_ + +# No FP32 master weights at all +use_fp32_master_weights_fused: null +use_fp32_master_weights: null + +wandb: + name: lingua-7b-pure-bf16 + id: lingua-7b-pure-bf16 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index 9302a0758d..0c20326dbf 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -32,6 +32,8 @@ dataset: wandb: name: ??? project: null # Optional: set to your wandb project name + id: null # Set to a fixed ID to resume the same run across restarts + resume: allow # "allow" resumes if id exists, else creates new run # TransformerEngine FP8 config. See # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more information on @@ -41,6 +43,8 @@ fp8_config: fp8_recipe: transformer_engine.common.recipe.DelayedScaling fp8_format: "HYBRID" fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: false # If this is set to true, fp8_config.enabled must also be set to true. fp4_config: enabled: false @@ -74,10 +78,19 @@ checkpoint: logger: frequency: 100 -fp8_stats_config: +quant_stats_config: enabled: false - fp8_stats_file: ./fp8_debugging_stats.yaml - fp8_log_dir: ./log_fp8_stats + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. +fp8_layers: null +fp4_layers: null +use_fp32_master_weights_fused: null # Use TE FusedAdam for FP32 master weights + +gradient_debug: + enabled: false + log_every_n_steps: 1 profiler: enabled: false diff --git a/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-70B/config.json b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-70B/config.json new file mode 100644 index 0000000000..bd1408afc6 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-70B/config.json @@ -0,0 +1,35 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.3", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json new file mode 100644 index 0000000000..460f2f1b71 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json @@ -0,0 +1,35 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.3", + "use_cache": true, + "vocab_size": 128256 +} diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py index 62171cd237..994c4f876b 100644 --- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py @@ -58,6 +58,7 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + layer_precision: list[str | None] | None = None def __init__( self, @@ -223,11 +224,54 @@ def _init_method(x): self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq + self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None + self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + def set_recipes( + self, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + ) -> None: + """Attach quantization recipe objects for per-layer autocast. + + Recipes are not serializable and must be set at runtime after model creation + and sharding (FSDP/DDP) but before training. The per-layer precision + assignments are read from ``self.config.layer_precision``. + + Args: + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. + """ + precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, input_ids: torch.Tensor | None = None, @@ -304,12 +348,14 @@ def forward( if te_rope_emb.dtype != torch.float32: warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning) - with self.get_autocast_context(None, outer=True): - for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled + # by get_layer_autocast(), which nests inside this context. + with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe): + for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): if output_hidden_states: all_hidden_states = (*all_hidden_states, hidden_states) - with self.get_autocast_context(layer_idx): + with self.get_layer_autocast(layer_number): hidden_states = decoder_layer( hidden_states, attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, @@ -369,8 +415,12 @@ def get_autocast_context( if init and self.config.use_quantized_model_init: if precision in ("fp8", "fp4"): - return transformer_engine.pytorch.quantized_model_init(recipe=recipe) - return nullcontext() + # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext() + # preserves the outer context's settings (recipe, preserve_high_precision_init_val). + # A nested quantized_model_init would override preserve_high_precision_init_val to False. + return nullcontext() + # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context. + return transformer_engine.pytorch.quantized_model_init(enabled=False) if precision == "fp8": if recipe is None: @@ -597,8 +647,3 @@ def reorder_cache(self, beam_idx: torch.LongTensor): updated_key_cache = key_cache.index_select(0, beam_idx) updated_value_cache = value_cache.index_select(0, beam_idx) self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache) - - @property - def is_compileable(self) -> bool: - """Return False as this cache is not compatible with torch.compile.""" - return False diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 726eb19e8e..4b1a8d4ec7 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -91,7 +91,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: self.grad_acc_step_count = 0 # Whether to step debug_api.step() after each step - self.fp8_stats_enabled = args.fp8_stats_config.enabled + self.quant_stats_config = args.quant_stats_config.enabled @nvtx.annotate("PerfLogger.log_micro_step", color="pink") def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast): @@ -150,7 +150,7 @@ def log_step( if self._profiler is not None: self._profiler.step(step) - if self.fp8_stats_enabled: + if self.quant_stats_config: debug_api.step() if step % self.logging_frequency == 0 and step > 0: @@ -201,15 +201,15 @@ def log_step( def finish(self): """Finish the logger and close the progress bar.""" + if self.quant_stats_config: + debug_api.end_debug() + if not self._dist_config.is_main_process(): return wandb.finish() self._progress_bar.close() - if self.fp8_stats_enabled: - debug_api.end_debug() - class NsightProfiler: """Nsight Systems profiler wrapper for performance analysis. diff --git a/bionemo-recipes/recipes/llama3_native_te/quantization.py b/bionemo-recipes/recipes/llama3_native_te/quantization.py new file mode 100644 index 0000000000..e479b13c02 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/quantization.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for layer-wise quantization configuration (FP8/FP4).""" + +import logging +import tempfile +from pathlib import Path + +import yaml + + +logger = logging.getLogger(__name__) + + +def generate_layer_regex(layer_numbers: list[int] | None) -> str: + """Generate a regex pattern to match specific layer numbers (1-indexed). + + The debug API (nvdlfw_inspect) uses 1-indexed layer names after ``infer_and_assign_layer_names``. + + Args: + layer_numbers: List of layer numbers (1-indexed, as shown in debug logs). + If empty or None, returns a pattern that matches nothing. + + Returns: + Regex pattern string for matching those layers' linear sublayers. + """ + if not layer_numbers: + return r"model\.model\.layers\.DISABLED_NO_LAYERS_SPECIFIED" + layer_pattern = "|".join(str(n) for n in sorted(layer_numbers)) + return rf"model\.model\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)" + + +def update_quant_stats_config( + config_file: str, + fp4_layers: list[int] | None, + fp8_layers: list[int] | None, +) -> str: + """Update the quant stats YAML config with layer-specific regex patterns. + + Args: + config_file: Path to the original YAML config file. + fp4_layers: List of layer numbers for FP4 (1-indexed). + fp8_layers: List of layer numbers for FP8 (1-indexed). + + Returns: + Path to the updated config file (a temp file). + """ + with open(config_file, "r") as f: + config = yaml.safe_load(f) + + if "example_fp4_tensor_stat_collection" in config: + # TODO: Remove this block and replace with FP8-style regex update once a TransformerEngine + # release with LogNvfp4TensorStats support is available. At that point, this becomes: + # fp4_regex = generate_layer_regex(fp4_layers) + # config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex + config["example_fp4_tensor_stat_collection"]["enabled"] = False + if fp4_layers: + logger.warning( + "NVFP4 quant stats logging is not yet supported (requires a future TransformerEngine release). " + f"Disabling FP4 stats collection for layers {fp4_layers}. FP8 stats will still be collected." + ) + else: + logger.info("FP4 stats section disabled (no FP4 layers and feature not yet supported)") + + if "example_fp8_tensor_stat_collection" in config: + fp8_regex = generate_layer_regex(fp8_layers) + config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex + if fp8_layers: + logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + else: + logger.info("FP8 layers empty - regex set to match nothing") + + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + yaml.dump(config, temp_file, default_flow_style=False) + temp_file.close() + + config_str = yaml.dump(config, default_flow_style=False) + logger.info(f"Created updated quant stats config at: {temp_file.name}") + logger.info(f"Updated quant stats config contents:\n{config_str}") + + return temp_file.name + + +def initialize_quant_stats_logging( + quant_stats_file: str, + quant_log_dir: str, + rank: int, + layer_precision: list[str | None], +) -> None: + """Set up quantization stats logging via nvdlfw_inspect. + + Updates the quant stats YAML config with resolved layer regex patterns, creates + the per-rank log directory, and initializes the debug API. + + Args: + quant_stats_file: Path to the base quant stats YAML config file. + quant_log_dir: Base directory for quant stats logs (a rank subdirectory will be created). + rank: The global rank of this process. + layer_precision: Per-layer precision list (0-indexed by position). Each element is + ``"fp8"``, ``"fp4"``, or ``None``. + """ + import nvdlfw_inspect.api as debug_api + import transformer_engine + + # Derive 1-indexed layer lists for the debug API, which uses 1-indexed layer names. + fp8_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp8"] or None + fp4_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp4"] or None + updated_config = update_quant_stats_config( + config_file=quant_stats_file, + fp4_layers=fp4_layers_1indexed, + fp8_layers=fp8_layers_1indexed, + ) + + rank_log_dir = Path(quant_log_dir) / f"rank_{rank}" + rank_log_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Logging quant stats to {rank_log_dir}") + + te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") + debug_api.initialize( + config_file=updated_config, + feature_dirs=[te_features_dir], + log_dir=rank_log_dir, + default_logging_enabled=True, + ) + + +def resolve_layer_precision( + num_layers: int, + fp8_enabled: bool, + fp4_enabled: bool, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, +) -> list[str | None]: + """Resolve layer-wise quantization assignments from user config. + + Takes 1-indexed layer lists (as specified by the user in YAML config) and returns a per-layer + precision list (0-indexed by position). When a quantization format is enabled but no layer list + is provided, all layers default to that format. When one format has explicit layers and the other + is enabled without a layer list, the unspecified format defaults to the remaining (unclaimed) layers. + + Args: + num_layers: Total number of transformer layers in the model. + fp8_enabled: Whether FP8 quantization is enabled. + fp4_enabled: Whether FP4 quantization is enabled. + fp8_layers: 1-indexed list of layers for FP8, or None if not specified. + fp4_layers: 1-indexed list of layers for FP4, or None if not specified. + + Returns: + A list of length ``num_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None`` + (BF16 fallback), indexed by layer position (0-indexed). + + Raises: + ValueError: If both formats are enabled with no layer lists, or if layer lists overlap. + """ + all_layers = set(range(1, num_layers + 1)) + + if fp8_enabled and fp4_enabled and fp8_layers is None and fp4_layers is None: + raise ValueError( + "Both fp8_config and fp4_config are enabled but neither fp8_layers nor fp4_layers is specified. " + "When both are enabled, you must explicitly provide layer lists to indicate which layers use which format." + ) + + # When one format has explicit layers and the other defaults, fill in the remaining layers. + if fp8_enabled and fp8_layers is None: + claimed_by_fp4 = set(fp4_layers) if fp4_layers is not None else set() + fp8_layers = sorted(all_layers - claimed_by_fp4) + if claimed_by_fp4: + logger.warning( + f"fp8_config.enabled=True with no fp8_layers specified, but fp4_layers={sorted(claimed_by_fp4)} " + f"are already claimed by FP4. Defaulting FP8 to the remaining layers: {fp8_layers}" + ) + else: + logger.info( + f"fp8_config.enabled=True with no fp8_layers specified, defaulting all {num_layers} layers to FP8" + ) + + if fp4_enabled and fp4_layers is None: + claimed_by_fp8 = set(fp8_layers) if fp8_layers is not None else set() + fp4_layers = sorted(all_layers - claimed_by_fp8) + if claimed_by_fp8: + logger.warning( + f"fp4_config.enabled=True with no fp4_layers specified, but fp8_layers={sorted(claimed_by_fp8)} " + f"are already claimed by FP8. Defaulting FP4 to the remaining layers: {fp4_layers}" + ) + else: + logger.info( + f"fp4_config.enabled=True with no fp4_layers specified, defaulting all {num_layers} layers to FP4" + ) + + # Disable layer lists when corresponding config is not enabled. + if not fp8_enabled: + fp8_layers = None + if not fp4_enabled: + fp4_layers = None + + # Validate no overlap between FP8 and FP4 layer assignments. + if fp8_layers is not None and fp4_layers is not None: + overlap = set(fp8_layers) & set(fp4_layers) + if overlap: + raise ValueError( + f"fp8_layers and fp4_layers cannot have overlapping layer numbers. Found overlap: {sorted(overlap)}" + ) + + # Build per-layer precision list (0-indexed by position, 1-indexed for lookup). + fp8_set = set(fp8_layers) if fp8_layers is not None else set() + fp4_set = set(fp4_layers) if fp4_layers is not None else set() + return [ + "fp8" if layer_1indexed in fp8_set else "fp4" if layer_1indexed in fp4_set else None + for layer_1indexed in range(1, num_layers + 1) + ] diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt index 40f36f659d..8a15cec936 100644 --- a/bionemo-recipes/recipes/llama3_native_te/requirements.txt +++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt @@ -5,9 +5,7 @@ torchao!=0.14.0 torchdata torchmetrics tqdm -transformer_engine[pytorch] transformers wandb zstandard nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect -pytest diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py index 08330b12f7..bb7a2d8ed6 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py @@ -56,6 +56,8 @@ def pytest_collection_modifyitems(items): stats_test_names = { "test_sanity_ddp_fp8_stats_logging", "test_sanity_fsdp2_fp8_stats_logging", + "test_sanity_ddp_fp8_partial_layers_stats_logging", + "test_sanity_fsdp2_fp8_partial_layers_stats_logging", } stats_tests = [item for item in items if item.name in stats_test_names] other_tests = [item for item in items if item.name not in stats_test_names] diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_mxfp8_fsdp2_checkpoint_resume.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_mxfp8_fsdp2_checkpoint_resume.py new file mode 100644 index 0000000000..f4aaf03404 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_mxfp8_fsdp2_checkpoint_resume.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal reproduction of FSDP2 + MXFP8 checkpoint resume crash. + +Bug: After fully_shard() wraps a model with quantized_model_init (MXFP8) params, +checkpoint resume via set_state_dict crashes with: + RuntimeError: Attempted to access the data pointer on an invalid python storage. + +Root cause: set_state_dict -> model.load_state_dict -> copy_() on MXFP8Tensor +re-quantizes, allocating new internal storage. FSDP2's reset_sharded_param +(post-load hook) then calls untyped_storage().data_ptr() on the invalidated +storage. PyTorch has a "# TODO: need to support tensor subclass" comment at +the crash site (_fsdp_param.py line 892). + +Fix: Wrap the data_ptr() comparison in try/except RuntimeError. When it fails, +treat as same_local_tensor=False so _sharded_param_data gets re-recorded. + +Run with: torchrun --nproc_per_node=2 test_mxfp8_fsdp2_checkpoint_resume.py +""" + +import argparse +import shutil + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import transformer_engine.pytorch as te +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam +from torch.distributed.tensor import DTensor +from torch.nn import functional as f_nn +from transformer_engine.common.recipe import MXFP8BlockScaling +from transformer_engine.pytorch.optimizers import FusedAdam +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + + +HIDDEN = 256 +FFN_HIDDEN = 1024 +NUM_HEADS = 8 +NUM_LAYERS = 2 +SEQ_LEN = 32 +BATCH = 2 + + +def apply_reset_sharded_param_fix(): + """Monkey-patch FSDPParam.reset_sharded_param to handle QuantizedTensor. + + After checkpoint load, copy_() on MXFP8Tensor re-quantizes which can + invalidate the old untyped_storage, causing data_ptr() to crash. + This wraps the comparison in try/except so reset_sharded_param can + proceed normally (re-recording _sharded_param_data). + """ + + def _patched_reset_sharded_param(self): + module_info = self._module_info + new_param = getattr(module_info.module, module_info.param_name) + if new_param is not self.sharded_param: + if torch.__future__.get_swap_module_params_on_conversion(): + raise AssertionError( + f"Expects swap_tensors to preserve object but got {new_param} instead of {self.sharded_param}" + ) + self.sharded_param = new_param + + local_tensor = new_param._local_tensor + if local_tensor.is_meta: + return + + updated_local_tensor = False + same_local_tensor = False + + if type(self._sharded_param_data) is torch.Tensor: + try: + same_local_tensor = ( + self._sharded_param_data.untyped_storage().data_ptr() > 0 + and self._sharded_param_data.untyped_storage().data_ptr() + == local_tensor.untyped_storage().data_ptr() + ) + except RuntimeError: + # QuantizedTensor (e.g. MXFP8Tensor) can have invalid storage + # after copy_() re-quantization. + same_local_tensor = False + + padded_sharded_size = self.padded_sharded_param_size + shard_dim = self.fsdp_placement.dim + length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0 + + if local_tensor.size() != padded_sharded_size and not same_local_tensor: + if shard_dim != 0: + raise AssertionError(f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}") + padded_local_tensor = local_tensor.new_zeros(padded_sharded_size) + padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(local_tensor) + local_tensor = padded_local_tensor + updated_local_tensor = True + + if self.pin_memory and not local_tensor.is_pinned(): + local_tensor = local_tensor.cpu().pin_memory() + updated_local_tensor = True + + if not same_local_tensor: + self._sharded_param_data = local_tensor.view(-1) + + if not isinstance(self.sharded_param, DTensor): + raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}") + + if updated_local_tensor: + self.sharded_param._local_tensor = local_tensor.narrow(dim=shard_dim, start=0, length=length) + if not self.sharded_param._local_tensor.is_contiguous(): + raise AssertionError("Expected sharded_param._local_tensor to be contiguous") + + self._sharding_spec = self.sharded_param._spec + + FSDPParam.reset_sharded_param = _patched_reset_sharded_param + + +def _save_custom_attrs(model): + """Save custom attrs on QuantizedTensor params (lost during fully_shard + reset_parameters).""" + attrs = {} + for name, param in model.named_parameters(): + local = param._local_tensor if isinstance(param, DTensor) else param + if isinstance(local, QuantizedTensor): + param_attrs = {} + for attr_name in dir(local): + if not attr_name.startswith("_") and not callable(getattr(local, attr_name, None)): + try: + param_attrs[attr_name] = getattr(local, attr_name) + except Exception: + pass + attrs[name] = param_attrs + return attrs + + +def _restore_custom_attrs(model, attrs): + """Restore custom attrs on QuantizedTensor params.""" + for name, param in model.named_parameters(): + if name in attrs: + local = param._local_tensor if isinstance(param, DTensor) else param + if isinstance(local, QuantizedTensor): + for attr_name, attr_val in attrs[name].items(): + try: + setattr(local, attr_name, attr_val) + except Exception: + pass + + +def build_model(recipe): + """Build model with quantized_model_init on meta device.""" + with te.quantized_model_init( + recipe=recipe, + enabled=True, + preserve_high_precision_init_val=True, + ): + model = torch.nn.Sequential( + *[ + te.TransformerLayer( + HIDDEN, + FFN_HIDDEN, + NUM_HEADS, + fuse_qkv_params=True, + params_dtype=torch.bfloat16, + hidden_dropout=0.0, + attention_dropout=0.0, + device="meta", + ) + for _ in range(NUM_LAYERS) + ] + ) + return model + + +def shard_model(model, mesh): + """Apply FSDP2 sharding, then materialize meta params via reset_parameters.""" + has_meta = any(p.is_meta for p in model.parameters()) + custom_attrs = _save_custom_attrs(model) + for child in model.children(): + fully_shard(child, mesh=mesh) + fully_shard(model, mesh=mesh) + if has_meta: + for module in model.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + _restore_custom_attrs(model, custom_attrs) + return model + + +def build_and_shard(recipe, mesh, device): + """Build model, shard, create optimizer, run one step to populate optimizer state.""" + model = build_model(recipe) + model = shard_model(model, mesh) + + optimizer = FusedAdam( + model.parameters(), + lr=1e-3, + master_weights=True, + master_weight_dtype=torch.float32, + ) + + # Run one training step to populate optimizer state + x = torch.randn(SEQ_LEN, BATCH, HIDDEN, dtype=torch.bfloat16, device=device) + target = torch.randn_like(x) + optimizer.zero_grad(set_to_none=True) + with te.autocast(enabled=True, recipe=recipe): + out = model(x) + loss = f_nn.mse_loss(out, target) + loss.backward() + optimizer.step() + + return model, optimizer + + +def run(apply_fix: bool): + """Run the reproduction: save checkpoint, load it, verify forward pass.""" + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + rank = dist.get_rank() + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank}") + world_size = dist.get_world_size() + mesh = DeviceMesh("cuda", list(range(world_size))) + + recipe = MXFP8BlockScaling() + + if apply_fix: + apply_reset_sharded_param_fix() + if rank == 0: + print("Applied reset_sharded_param fix") + + # Build model, train one step, save checkpoint + model, optimizer = build_and_shard(recipe, mesh, device) + if rank == 0: + print("Model built and trained for 1 step") + + # Record reference output + x = torch.randn(SEQ_LEN, BATCH, HIDDEN, dtype=torch.bfloat16, device=device) + with torch.no_grad(), te.autocast(enabled=True, recipe=recipe): + ref_output = model(x).clone() + if rank == 0: + print(f"Reference output recorded, norm={ref_output.norm().item():.4f}") + + checkpoint_dir = "/tmp/te_test_mxfp8_fsdp2_ckpt_resume" + if rank == 0: + shutil.rmtree(checkpoint_dir, ignore_errors=True) + dist.barrier() + + try: + # Save checkpoint + model_state = {k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")} + dcp.save({"model": model_state, "optimizer": optimizer.state_dict()}, checkpoint_id=checkpoint_dir) + dist.barrier() + if rank == 0: + print(f"Checkpoint saved to {checkpoint_dir}") + + # Build fresh model + model2, optimizer2 = build_and_shard(recipe, mesh, device) + if rank == 0: + print("Fresh model built, loading checkpoint...") + + # Load checkpoint — THIS IS WHERE THE CRASH HAPPENS WITHOUT THE FIX + model2_state = {k: v for k, v in model2.state_dict().items() if not k.endswith("_extra_state")} + state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()} + dcp.load(state_to_load, checkpoint_id=checkpoint_dir) + model2.load_state_dict(state_to_load["model"], strict=False) + optimizer2.load_state_dict(state_to_load["optimizer"]) + dist.barrier() + if rank == 0: + print("Checkpoint loaded successfully!") + + # Verify output matches + with torch.no_grad(), te.autocast(enabled=True, recipe=recipe): + loaded_output = model2(x) + + torch.testing.assert_close( + loaded_output, + ref_output, + rtol=0, + atol=0, + msg=lambda m: f"Output mismatch after checkpoint load: {m}", + ) + if rank == 0: + print("Output parity verified — bitwise identical!") + + finally: + dist.barrier() + if rank == 0: + shutil.rmtree(checkpoint_dir, ignore_errors=True) + + dist.destroy_process_group() + if rank == 0: + print("SUCCESS" if apply_fix else "SUCCESS (unexpected — bug may be fixed upstream)") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--fix", action="store_true", help="Apply the reset_sharded_param monkey-patch fix") + args = parser.parse_args() + torch.manual_seed(42) + torch.cuda.manual_seed(42) + run(apply_fix=args.fix) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index aebdfe17ef..d919278d4a 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -34,7 +34,7 @@ def _make_args(logging_frequency=1, num_train_steps=100): "wandb": {"project": "test", "mode": "disabled"}, "num_train_steps": num_train_steps, "profiler": {"enabled": False}, - "fp8_stats_config": {"enabled": False}, + "quant_stats_config": {"enabled": False}, } ) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py new file mode 100644 index 0000000000..2d6e02b050 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py @@ -0,0 +1,332 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys +from pathlib import Path + +import pytest +import yaml + + +sys.path.append(Path(__file__).parent.parent.as_posix()) + +from quantization import generate_layer_regex, resolve_layer_precision, update_quant_stats_config + + +# -- resolve_layer_precision -- + + +def test_fp8_enabled_no_layers_defaults_all(): + """When fp8 is enabled with no explicit layers, all layers should default to FP8.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result == ["fp8", "fp8", "fp8", "fp8", "fp8", "fp8"] + + +def test_fp4_enabled_no_layers_defaults_all(): + """When fp4 is enabled with no explicit layers, all layers should default to FP4.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=None + ) + assert result == ["fp4", "fp4", "fp4", "fp4", "fp4", "fp4"] + + +def test_fp8_explicit_layers(): + """Explicit 1-indexed fp8_layers should produce fp8 at those positions.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[1, 3, 5], fp4_layers=None + ) + assert result == ["fp8", None, "fp8", None, "fp8", None] + + +def test_fp4_explicit_layers(): + """Explicit 1-indexed fp4_layers should produce fp4 at those positions.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=[2, 4, 6] + ) + assert result == [None, "fp4", None, "fp4", None, "fp4"] + + +def test_mixed_fp8_fp4_explicit(): + """Both enabled with explicit non-overlapping layers should work correctly.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 3, 4], fp4_layers=[2, 5] + ) + assert result == ["fp8", "fp4", "fp8", "fp8", "fp4", None] + + +def test_both_enabled_no_layers_raises(): + """Both enabled with no layer lists should raise ValueError.""" + with pytest.raises(ValueError, match="Both fp8_config and fp4_config are enabled"): + resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=None) + + +def test_overlapping_layers_raises(): + """Overlapping layer assignments should raise ValueError.""" + with pytest.raises(ValueError, match="fp8_layers and fp4_layers cannot have overlapping"): + resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=[3, 4, 5] + ) + + +def test_disabled_ignores_layers(): + """When a format is disabled, its layers should be ignored.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=[1, 2, 3], fp4_layers=[4, 5, 6] + ) + assert result == [None, None, None, None, None, None] + + +def test_both_disabled(): + """Both disabled with no layers should return all None.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result == [None, None, None, None, None, None] + + +def test_large_model_defaults_all(): + """Auto-population should work correctly for larger models (e.g. 36 layers).""" + result = resolve_layer_precision( + num_layers=36, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result == ["fp8"] * 36 + + +def test_fp8_enabled_empty_list(): + """An explicit empty list should remain empty (not default to all).""" + result = resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[], fp4_layers=None) + assert result == [None, None, None, None, None, None] + + +def test_both_enabled_fp8_specified_fp4_defaults_to_remaining(): + """When both enabled, FP8 has explicit layers, FP4 should default to the remaining layers.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=None + ) + assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"] + + +def test_both_enabled_fp4_specified_fp8_defaults_to_remaining(): + """When both enabled, FP4 has explicit layers, FP8 should default to the remaining layers.""" + result = resolve_layer_precision( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=[4, 5, 6] + ) + assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"] + + +def test_returns_correct_length(): + """Result list length should always equal num_layers.""" + for n in [1, 6, 48]: + result = resolve_layer_precision( + num_layers=n, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert len(result) == n + + +# -- generate_layer_regex -- + + +def test_single_layer(): + """Single layer should produce a simple regex.""" + regex = generate_layer_regex([3]) + assert re.search(regex, "model.model.layers.3.self_attention.layernorm_qkv") + assert not re.search(regex, "model.model.layers.2.self_attention.layernorm_qkv") + + +def test_multiple_layers(): + """Multiple layers should match any of them.""" + regex = generate_layer_regex([1, 2, 3]) + assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv") + assert re.search(regex, "model.model.layers.2.layernorm_mlp.fc1") + assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2") + assert not re.search(regex, "model.model.layers.4.self_attention.proj") + + +def test_matches_correct_sublayers(): + """Regex should only match layernorm_qkv, proj, fc1, fc2.""" + regex = generate_layer_regex([1]) + assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv_something") + assert re.search(regex, "model.model.layers.1.self_attention.proj_something") + assert re.search(regex, "model.model.layers.1.layernorm_mlp.fc1_something") + assert re.search(regex, "model.model.layers.1.layernorm_mlp.fc2_something") + # Should not match unrelated sublayer names + assert not re.search(regex, "model.model.layers.1.self_attention.some_other_thing") + + +def test_none_returns_disabled_pattern(): + """None should return a pattern that matches nothing.""" + regex = generate_layer_regex(None) + assert "DISABLED" in regex + assert not re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv") + + +def test_empty_list_returns_disabled_pattern(): + """Empty list should return a pattern that matches nothing.""" + regex = generate_layer_regex([]) + assert "DISABLED" in regex + + +def test_1indexed_layer_names(): + """Regex should use 1-indexed layer numbers (matching debug API naming).""" + regex = generate_layer_regex([1]) + # Should match layers.1 (1-indexed first layer) + assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv") + # Should NOT match layers.0 (0-indexed first layer) + assert not re.search(regex, "model.model.layers.0.self_attention.layernorm_qkv") + + +# -- update_quant_stats_config -- + + +@pytest.fixture +def fp8_only_config(tmp_path): + """Create an FP8-only stats config file.""" + config = { + "example_fp8_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogFp8TensorStats": { + "enabled": True, + "tensors_struct": [{"tensor": "activation", "stats": ["underflows%"], "freq": 10}], + } + }, + } + } + config_path = tmp_path / "fp8_stats.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + return str(config_path) + + +@pytest.fixture +def fp4_fp8_config(tmp_path): + """Create a combined FP4+FP8 stats config file.""" + config = { + "example_fp4_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogNvfp4TensorStats": {"enabled": True}, + }, + }, + "example_fp8_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogFp8TensorStats": {"enabled": True}, + }, + }, + } + config_path = tmp_path / "fp4_fp8_stats.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + return str(config_path) + + +def test_fp8_layers_updates_regex(fp8_only_config): + """FP8 layer list should update the regex in the output config.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2, 3]) + with open(output_path) as f: + result = yaml.safe_load(f) + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv") + assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2") + assert not re.search(regex, "model.model.layers.4.self_attention.proj") + + +def test_none_layers_disables_matching(fp8_only_config): + """None layers should set regex to match nothing.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=None) + with open(output_path) as f: + result = yaml.safe_load(f) + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert "DISABLED" in regex + + +def test_fp4_section_disabled_fp8_still_updated(fp4_fp8_config): + """FP4 stats section should be disabled (not yet supported), FP8 should still be updated.""" + output_path = update_quant_stats_config(config_file=fp4_fp8_config, fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6]) + with open(output_path) as f: + result = yaml.safe_load(f) + + # FP4 section should be disabled + assert result["example_fp4_tensor_stat_collection"]["enabled"] is False + + # FP8 regex should still match layers 4-6 + fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(fp8_regex, "model.model.layers.5.self_attention.proj") + assert not re.search(fp8_regex, "model.model.layers.2.self_attention.proj") + + +def test_original_file_not_modified(fp8_only_config): + """update_quant_stats_config should write to a temp file, not modify the original.""" + with open(fp8_only_config) as f: + original_content = f.read() + + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2]) + + assert output_path != fp8_only_config + with open(fp8_only_config) as f: + assert f.read() == original_content + + +def test_preserves_other_config_fields(fp8_only_config): + """Non-layer fields in the config should be preserved.""" + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1]) + with open(output_path) as f: + result = yaml.safe_load(f) + # The transformer_engine section should still be there + assert result["example_fp8_tensor_stat_collection"]["transformer_engine"]["LogFp8TensorStats"]["enabled"] is True + + +def test_missing_section_is_skipped(fp8_only_config): + """If fp4 section doesn't exist in config, it should be silently skipped.""" + # fp8_only_config has no fp4 section -- passing fp4_layers should not error + output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=[1, 2], fp8_layers=[3, 4]) + with open(output_path) as f: + result = yaml.safe_load(f) + # Only FP8 section should exist and be updated + assert "example_fp4_tensor_stat_collection" not in result + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(regex, "model.model.layers.3.self_attention.layernorm_qkv") + + +def test_with_real_fp4_config(): + """Test with the actual fp4_debugging_stats.yaml file.""" + config_path = Path(__file__).parent.parent / "fp4_debugging_stats.yaml" + if not config_path.exists(): + pytest.skip("fp4_debugging_stats.yaml not found") + + output_path = update_quant_stats_config(config_file=str(config_path), fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6]) + with open(output_path) as f: + result = yaml.safe_load(f) + + # FP4 section should be disabled (not yet supported in current TE release) + assert result["example_fp4_tensor_stat_collection"]["enabled"] is False + + # FP8 section should still be updated and working + fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(fp8_regex, "model.model.layers.5.self_attention.proj") + assert not re.search(fp8_regex, "model.model.layers.2.self_attention.proj") diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py new file mode 100644 index 0000000000..831bfe4ee8 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for quantized_model_init with FL1 (first/last layer BF16) and all-FP8. + +Verifies that: +1. All-FP8 + qinit: all decoder layer weights are QuantizedTensors with high-precision init vals +2. FL1 + qinit: FP8 layers have QuantizedTensor weights, BF16 layers have regular BF16 weights +3. BF16 layers don't lose precision from an outer quantized_model_init context + +Uses Float8BlockScaling instead of MXFP8BlockScaling so tests run on non-Blackwell GPUs. +The quantized_model_init behavior is recipe-agnostic. +""" + +import sys +from pathlib import Path + +import pytest +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Float8BlockScaling +from transformer_engine.pytorch.tensor import QuantizedTensor + + +sys.path.append(Path(__file__).parent.parent.as_posix()) + +from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM + + +# Small model config for fast testing +_SMALL_CONFIG_KWARGS = { + "num_hidden_layers": 4, + "hidden_size": 256, + "intermediate_size": 512, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "vocab_size": 1024, + "max_position_embeddings": 128, +} + +requires_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _has_quantized_weights(layer) -> bool: + """Check if a TE TransformerLayer has any QuantizedTensor parameters.""" + for param in layer.parameters(): + if isinstance(param.data, QuantizedTensor): + return True + return False + + +def _has_high_precision_init_val(layer) -> bool: + """Check if any parameter in the layer has a high-precision init val.""" + for param in layer.parameters(): + if hasattr(param, "get_high_precision_init_val") and param.get_high_precision_init_val() is not None: + return True + return False + + +@requires_gpu +def test_all_fp8_qinit(): + """All layers FP8 with quantized_model_init: all weights should be QuantizedTensors.""" + recipe = Float8BlockScaling() + config = NVLlamaConfig( + **_SMALL_CONFIG_KWARGS, + attn_input_format="bshd", + dtype=torch.bfloat16, + ) + + with te.quantized_model_init(recipe=recipe, enabled=True, preserve_high_precision_init_val=True): + model = NVLlamaForCausalLM(config, fp8_recipe=recipe) + + # All decoder layers should have quantized weights + for i, layer in enumerate(model.model.layers): + assert _has_quantized_weights(layer), f"Layer {i} should have QuantizedTensor weights" + assert _has_high_precision_init_val(layer), f"Layer {i} should have high-precision init vals" + + +@requires_gpu +def test_fl1_qinit_bf16_layers_not_quantized(): + """FL1 + qinit: BF16 layers (first/last) should NOT have quantized weights.""" + recipe = Float8BlockScaling() + # FL1: layers 2,3 in FP8 (1-indexed), layers 1,4 in BF16 + layer_precision = [None, "fp8", "fp8", None] + config = NVLlamaConfig( + **_SMALL_CONFIG_KWARGS, + attn_input_format="bshd", + dtype=torch.bfloat16, + layer_precision=layer_precision, + use_quantized_model_init=True, + ) + + with te.quantized_model_init(recipe=recipe, enabled=True, preserve_high_precision_init_val=True): + model = NVLlamaForCausalLM(config, fp8_recipe=recipe) + + # BF16 layers (0 and 3, 0-indexed) should NOT have quantized weights + assert not _has_quantized_weights(model.model.layers[0]), "First layer (BF16) should not have QuantizedTensors" + assert not _has_quantized_weights(model.model.layers[3]), "Last layer (BF16) should not have QuantizedTensors" + + # FP8 layers (1 and 2, 0-indexed) should have quantized weights + assert _has_quantized_weights(model.model.layers[1]), "FP8 layer 1 should have QuantizedTensors" + assert _has_quantized_weights(model.model.layers[2]), "FP8 layer 2 should have QuantizedTensors" + + +@requires_gpu +def test_fl1_qinit_fp8_layers_preserve_high_precision(): + """FL1 + qinit: FP8 layers should preserve high-precision init vals for master weights.""" + recipe = Float8BlockScaling() + layer_precision = [None, "fp8", "fp8", None] + config = NVLlamaConfig( + **_SMALL_CONFIG_KWARGS, + attn_input_format="bshd", + dtype=torch.bfloat16, + layer_precision=layer_precision, + use_quantized_model_init=True, + ) + + with te.quantized_model_init(recipe=recipe, enabled=True, preserve_high_precision_init_val=True): + model = NVLlamaForCausalLM(config, fp8_recipe=recipe) + + # FP8 layers should have high-precision init values + assert _has_high_precision_init_val(model.model.layers[1]), "FP8 layer should have high-precision init vals" + assert _has_high_precision_init_val(model.model.layers[2]), "FP8 layer should have high-precision init vals" + + # BF16 layers should NOT have high-precision init values (they're already BF16) + assert not _has_high_precision_init_val(model.model.layers[0]), ( + "BF16 layer should not have high-precision init vals" + ) + assert not _has_high_precision_init_val(model.model.layers[3]), ( + "BF16 layer should not have high-precision init vals" + ) + + +@requires_gpu +def test_fl1_no_qinit_baseline(): + """FL1 without qinit: all weights should be regular BF16 tensors (baseline).""" + recipe = Float8BlockScaling() + layer_precision = [None, "fp8", "fp8", None] + config = NVLlamaConfig( + **_SMALL_CONFIG_KWARGS, + attn_input_format="bshd", + dtype=torch.bfloat16, + layer_precision=layer_precision, + ) + + # No quantized_model_init context — default behavior + model = NVLlamaForCausalLM(config, fp8_recipe=recipe) + + # No layers should have quantized weights + for i, layer in enumerate(model.model.layers): + assert not _has_quantized_weights(layer), f"Layer {i} should not have QuantizedTensors without qinit" diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index 89e85068de..0ddcbf1a60 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -439,6 +439,59 @@ def test_sanity_fsdp2_cp(tmp_path, recipe_path): assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite" +def test_sanity_convergence_fsdp2_te_fused_adam(tmp_path, recipe_path): + """Test FSDP2 training with TE FusedAdam for FP32 master weights. + + This test validates: + - FusedAdam optimizer initializes correctly with FSDP2-wrapped model + - Training converges with FP32 master weights maintained by FusedAdam + - FusedAdam is a drop-in replacement for the MixedPrecisionPolicy approach + """ + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "use_fp32_master_weights_fused=true", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" + + +def test_sanity_convergence_fsdp2_te_fused_adam_fp8(tmp_path, recipe_path): + """Test FSDP2 + FusedAdam + FP8 training. + + This test validates FusedAdam works correctly alongside FP8 quantization, + matching the approach used in the lingua 7B MXFP8 experiment config. + """ + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "use_fp32_master_weights_fused=true", + "fp8_config.enabled=true", + "use_sequence_packing=true", + "config_kwargs.attn_input_format=thd", + ], + ) + + final_loss = main_fsdp2(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" + + @requires_fp8 def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path): """Test that FP8 stats logging creates the expected log files.""" @@ -452,8 +505,8 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path): f"checkpoint.ckpt_dir={tmp_path}", "+dataset.pad_sequences_to_be_divisible_by=16", "fp8_config.enabled=true", - "fp8_stats_config.enabled=true", - f"fp8_stats_config.fp8_log_dir={fp8_log_dir}", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={fp8_log_dir}", "num_train_steps=4", ], ) @@ -493,8 +546,8 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): f"checkpoint.ckpt_dir={tmp_path}", "fp8_config.enabled=true", "+dataset.pad_sequences_to_be_divisible_by=16", - "fp8_stats_config.enabled=true", - f"fp8_stats_config.fp8_log_dir={fp8_log_dir}", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={fp8_log_dir}", "num_train_steps=4", ], ) @@ -507,6 +560,65 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists() +@requires_fp8 +def test_sanity_ddp_fp8_partial_layers_stats_logging(tmp_path, recipe_path): + """Test DDP training with layer-wise FP8 stats (layers 1-3 only).""" + quant_log_dir = tmp_path / "quant_stats_logs" + + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "+dataset.pad_sequences_to_be_divisible_by=16", + "fp8_config.enabled=true", + "fp8_layers=[1,2,3]", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={quant_log_dir}", + "num_train_steps=4", + ], + ) + + main_ddp(sanity_config) + + # Verify the log directory structure was created + assert quant_log_dir.exists(), "Quant log directory was not created" + assert (quant_log_dir / "rank_0").exists(), "rank_0 directory was not created" + assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs").exists(), "nvdlfw_inspect_logs directory was not created" + assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs").exists(), ( + "nvdlfw_inspect_statistics_logs directory was not created" + ) + + +@requires_fp8 +def test_sanity_fsdp2_fp8_partial_layers_stats_logging(tmp_path, recipe_path): + """Test FSDP2 training with layer-wise FP8 stats (layers 1-3 only).""" + quant_log_dir = tmp_path / "quant_stats_logs" + + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity", + overrides=[ + f"+wandb_init_args.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "+dataset.pad_sequences_to_be_divisible_by=16", + "fp8_config.enabled=true", + "fp8_layers=[1,2,3]", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={quant_log_dir}", + "num_train_steps=4", + ], + ) + + main_fsdp2(sanity_config) + + # Verify log structure + assert quant_log_dir.exists() + assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists() + assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists() + + def run_train_cmd(cmd, recipe_path): """Run a training command and check for errors. diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 413b9262c7..548931be67 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -42,9 +42,9 @@ from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from fp8_debugging import initialize_fp8_debugging from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_layer_precision from scheduler import get_cosine_annealing_schedule_with_warmup @@ -66,37 +66,81 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - if args.fp8_stats_config.enabled: - initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) + if args.use_fp32_master_weights: + raise ValueError("FP32 master weights are not supported with DDP. Use train_fsdp2.py instead.") # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2. device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",)) + if args.use_te: + config_class = NVLlamaConfig + model_class = NVLlamaForCausalLM + else: + config_class = LlamaConfig + model_class = LlamaForCausalLM + # --- Model Configuration --- - # Create quantization recipes -- only used if FP8/FP4 is enabled in the config. + config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + + # Resolve layer-wise quantization assignments and store on config. + layer_precision = resolve_layer_precision( + num_layers=config.num_hidden_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + config.layer_precision = layer_precision + + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=layer_precision, + ) + + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. fp8_recipe = None + fp4_recipe = None if args.fp8_config.enabled: fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) - - fp4_recipe = None if args.fp4_config.enabled: - fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs) + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + if args.fp8_config.quantized_model_init_kwargs.get("enabled", False) and not ( + args.fp8_config.enabled or args.fp4_config.enabled + ): + raise ValueError( + "fp8_config.quantized_model_init_kwargs.enabled=true requires fp8_config.enabled=true or " + "fp4_config.enabled=true. Enable at least one quantization format to use quantized model initialization." + ) # --- Model Initialization --- - if args.use_te: - config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) - model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) - else: - config = LlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) - model = LlamaForCausalLM(config) + # Optionally use transformer engine to initialize only fp8 versions of weights by setting + # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 + # and fp8 versions of weights are kept. + with transformer_engine.pytorch.quantized_model_init( + recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs + ): + model = ( + model_class(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + if model_class is NVLlamaForCausalLM + else model_class(config) + ) logger.info("Initialized Model:\n%s", model) + # Attach quantization recipes to the model (layer precision is already on config). + if isinstance(model, NVLlamaForCausalLM): + model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + # --- Distributed Wrapping (DDP) --- - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) model = model.to(device=device) @@ -157,9 +201,8 @@ def main(args: DictConfig) -> float | None: micro_step += 1 # DDP requires no_sync to skip all-reduce until the last microbatch in the accumulation window. with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext(): - # Forward pass with mixed precision. - with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): - outputs = model(**batch) + # Forward pass - quantization autocast is handled inside the model via set_recipes(). + outputs = model(**batch) # Backward pass - scale loss by grad_acc_steps for proper gradient averaging loss = outputs.loss / args.grad_acc_steps diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index da19daa2a7..803c0e2969 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -33,9 +33,11 @@ import transformer_engine.pytorch from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard +from torch.distributed.tensor import DTensor from torch.optim import AdamW from transformer_engine.common.recipe import Format +from transformer_engine.pytorch.optimizers import FusedAdam from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaForCausalLM @@ -48,9 +50,9 @@ ) from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig -from fp8_debugging import initialize_fp8_debugging from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_layer_precision from scheduler import get_cosine_annealing_schedule_with_warmup @@ -58,6 +60,97 @@ logger.setLevel(logging.INFO) +def _log_per_layer_gradient_norms(model: torch.nn.Module, step: int) -> dict[str, float]: + """Log per-layer gradient L2 norms, weight norms, and zero-gradient fractions. + + Debugging tool for FP8 gradient underflow. Groups parameters by decoder layer, + embed, and lm_head, then computes grad_norm, weight_norm, and the fraction of + gradient elements that are exactly zero (underflow indicator). + + Args: + model: The model to inspect (must have gradients populated, before optimizer.step). + step: Current training step (for logging context). + + Returns: + Dictionary of metric_name -> value, ready for wandb.log(). + """ + metrics: dict[str, float] = {} + layer_groups: dict[str, list[tuple[str, torch.nn.Parameter]]] = {} + + for name, param in model.named_parameters(): + if param.grad is None: + continue + if "model.layers." in name: + parts = name.split(".") + idx = parts[parts.index("layers") + 1] + group = f"layer_{idx}" + elif "embed_tokens" in name: + group = "embed" + elif "lm_head" in name: + group = "lm_head" + elif "norm" in name and "layers" not in name: + group = "final_norm" + else: + group = "other" + layer_groups.setdefault(group, []).append((name, param)) + + for group, params in sorted(layer_groups.items()): + grad_sq, weight_sq, total_el, zero_el = 0.0, 0.0, 0, 0 + for _name, param in params: + grad = param.grad + if isinstance(grad, DTensor): + grad = grad.to_local() + local_param = param._local_tensor if isinstance(param, DTensor) else param + g = grad.float().flatten() + grad_sq += (g * g).sum().item() + weight_sq += (local_param.float().flatten() ** 2).sum().item() + total_el += g.numel() + zero_el += (g == 0).sum().item() + + metrics[f"grad_debug/{group}/grad_norm"] = grad_sq**0.5 + metrics[f"grad_debug/{group}/weight_norm"] = weight_sq**0.5 + metrics[f"grad_debug/{group}/grad_zero_frac"] = zero_el / max(total_el, 1) + + return metrics + + +def _init_master_weights_from_high_precision( + optimizer: FusedAdam, model: torch.nn.Module, device: torch.device +) -> None: + """Initialize optimizer master weights from high-precision init values. + + When quantized_model_init is used with preserve_high_precision_init_val=True, each FP8 parameter + stores the original BF16 init values in CPU memory. This function initializes optimizer state + for all parameters, then overwrites master weights for quantized params with the preserved + high-precision values instead of dequantized FP8 values. + + Follows the TE example: + https://github.com/NVIDIA/TransformerEngine/blob/main/examples/pytorch/quantized_model_init/fully_shard.py + """ + count = 0 + for name, param in model.named_parameters(): + # Eagerly initialize optimizer state for all parameters. + # TE main's FusedAdam handles DTensor + QuantizedTensor natively. + optimizer.initialize_state(param, store_param_remainders=False) + + # For quantized params, overwrite master weights with the preserved high-precision + # init values (instead of the dequantized FP8 values set by initialize_state). + local = param._local_tensor if isinstance(param, DTensor) else param + if hasattr(local, "get_high_precision_init_val"): + hp_val = local.get_high_precision_init_val() + if hp_val is not None: + optimizer.set_scaled_state(param, "master_param", hp_val.to(device=device, dtype=torch.float32)) + local.clear_high_precision_init_val() + count += 1 + logger.debug("Seeded master weight for %s from high-precision init val", name) + if count > 0: + logger.info("Initialized %d master weight(s) from high-precision init values", count) + else: + logger.info( + "No parameters with high-precision init values found (quantized_model_init may not have been used)" + ) + + @hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") def main(args: DictConfig) -> float | None: """Train Llama3 with TE layers using FSDP2. @@ -72,41 +165,100 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - MUST be done BEFORE FSDP wrapping - if args.fp8_stats_config.enabled: - initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) - device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",)) + if args.use_te: + config_class = NVLlamaConfig + model_class = NVLlamaForCausalLM + else: + config_class = LlamaConfig + model_class = LlamaForCausalLM + # --- Model Configuration --- - # Create quantization recipes -- only used if FP8/FP4 is enabled in the config. + config = config_class.from_pretrained( + args.config_name_or_path, + dtype=torch.bfloat16, + **args.config_kwargs, + ) + + # Resolve layer-wise quantization assignments and store on config. + layer_precision = resolve_layer_precision( + num_layers=config.num_hidden_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + config.layer_precision = layer_precision + + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=layer_precision, + ) + + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. fp8_recipe = None + fp4_recipe = None if args.fp8_config.enabled: fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) - - fp4_recipe = None if args.fp4_config.enabled: - fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs) + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) + + if args.fp8_config.quantized_model_init_kwargs.get("enabled", False) and not ( + args.fp8_config.enabled or args.fp4_config.enabled + ): + raise ValueError( + "fp8_config.quantized_model_init_kwargs.enabled=true requires fp8_config.enabled=true or " + "fp4_config.enabled=true. Enable at least one quantization format to use quantized model initialization." + ) # --- Model Initialization --- - if args.use_te: - config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) - with torch.device("meta") if args.use_meta_device else nullcontext(): - model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) - else: - config = LlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) - with torch.device("meta") if args.use_meta_device else nullcontext(): - model = LlamaForCausalLM(config) + # Optionally use transformer engine to initialize only fp8 versions of weights by setting + # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 + # and fp8 versions of weights are kept. + with ( + torch.device("meta") if args.use_meta_device else nullcontext(), + transformer_engine.pytorch.quantized_model_init( + recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs + ), + ): + model = ( + model_class(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + if model_class is NVLlamaForCausalLM + else model_class(config) + ) logger.info("Initialized Model:\n%s", model) + def _log_memory(tag: str) -> None: + """Log GPU memory stats.""" + alloc = torch.cuda.memory_allocated() / (1024**3) + reserved = torch.cuda.memory_reserved() / (1024**3) + peak = torch.cuda.max_memory_allocated() / (1024**3) + logger.info("[Memory: %s] allocated=%.2f GB, reserved=%.2f GB, peak=%.2f GB", tag, alloc, reserved, peak) + + _log_memory("after_model_init") + # --- Distributed Wrapping (FSDP2) --- + mp_policy = MixedPrecisionPolicy() + # Each decoder layer should be individually sharded before sharding the full model. for layer in model.model.layers: - fully_shard(layer, mesh=device_mesh["dp"]) - fully_shard(model, mesh=device_mesh["dp"]) + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + + _log_memory("after_fsdp_wrap") + + # Attach quantization recipes to the model (layer precision is already on config). + if isinstance(model, NVLlamaForCausalLM): + model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. if args.use_meta_device: @@ -118,13 +270,22 @@ def main(args: DictConfig) -> float | None: model.apply(model._init_weights) # Assign names to layers so debug API can identify them - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) # --- Optimizer & Scheduler --- # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). - optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + adamw_kwargs = OmegaConf.to_container(args.adamw_kwargs, resolve=True) + if args.use_fp32_master_weights_fused: + # TE FusedAdam maintains FP32 master copies of BF16 params internally. + # 'fused' kwarg is not used by TE's FusedAdam (it's always fused). + adamw_kwargs.pop("fused", None) + optimizer = FusedAdam(model.parameters(), master_weights=True, **adamw_kwargs) # type: ignore + logger.info("Using TE FusedAdam with FP32 master weights") + else: + optimizer = AdamW(model.parameters(), **adamw_kwargs) # type: ignore scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + _log_memory("after_optimizer_init") if args.use_torch_compile: # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency. @@ -140,21 +301,32 @@ def main(args: DictConfig) -> float | None: ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None if args.checkpoint.resume_from_checkpoint and ckpt_path: logger.info("Attempting to load checkpoint from %s", ckpt_path) - model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2( + model, optimizer, scheduler, _dl, start_step, epoch = load_checkpoint_fsdp2( model=model, optimizer=optimizer, scheduler=scheduler, ckpt_path=ckpt_path, dist_config=dist_config, - dataloader=train_dataloader, + dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, process_group=device_mesh.get_group("dp"), ) + if _dl is not None: + train_dataloader = _dl logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) else: logger.info("No checkpoint to load, starting from scratch") start_step = 0 epoch = 0 + # When starting from scratch with quantized_model_init + preserve_high_precision_init_val, + # seed FP32 master weights from the original high-precision init values (not dequantized FP8). + # Skip on resume — checkpoint already has correct master weights, and eager dequantize() can + # invalidate QuantizedTensor storage causing FSDP2 forward failures. + if args.use_fp32_master_weights_fused and args.fp8_config.quantized_model_init_kwargs.get( + "preserve_high_precision_init_val", False + ): + _init_master_weights_from_high_precision(optimizer, model, device) + perf_logger = PerfLogger(dist_config, args, start_step=start_step) gc.collect() @@ -165,14 +337,27 @@ def main(args: DictConfig) -> float | None: step = start_step micro_step = 0 # Gradient accumulation step counter while step < args.num_train_steps: - for batch in train_dataloader: + try: + dataloader_iter = iter(train_dataloader) + except ValueError as e: + if "last_yielded_worker_id does not match" in str(e): + # StatefulDataLoader's naive fast-forward replayed all items but ended on a + # different worker than saved — the streaming IterableDataset is non-deterministic + # across restarts (tokenize-with-windowing produces variable items per document). + # Clear the saved state and restart the dataloader from the beginning of the stream. + logger.warning("Dataloader state incompatible after fast-forward (%s), restarting from scratch.", e) + train_dataloader.next_iter_state = None + train_dataloader._iterator = None + dataloader_iter = iter(train_dataloader) + else: + raise + for batch in dataloader_iter: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 micro_step += 1 - # Forward pass with mixed precision. - with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): - outputs = model(**batch) + # Forward pass - quantization autocast is handled inside the model via set_recipes(). + outputs = model(**batch) # Backward pass - scale loss by grad_acc_steps for proper gradient averaging loss = outputs.loss / args.grad_acc_steps @@ -188,6 +373,14 @@ def main(args: DictConfig) -> float | None: # Compute and clip gradient norms. total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + # Per-layer gradient debug logging (before optimizer.step clears grads). + if args.gradient_debug.enabled and step % args.gradient_debug.log_every_n_steps == 0: + grad_metrics = _log_per_layer_gradient_norms(model, step) + if dist_config.is_main_process(): + import wandb as _wandb + + _wandb.log(grad_metrics, step=step) + # Step optimizer. optimizer.step() scheduler.step() @@ -220,6 +413,7 @@ def main(args: DictConfig) -> float | None: # Dataloader exhausted, incrementing epoch epoch += 1 + logger.warning("Dataloader exhausted at step %s, incrementing epoch to %s", step, epoch) dataset_or_sampler.set_epoch(epoch) # --- Cleanup --- diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index eaf1a1b39f..fefd14367b 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -29,14 +29,17 @@ from pathlib import Path import hydra +import nvdlfw_inspect.api as debug_api import nvtx import torch import transformer_engine.pytorch from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard +from torch.distributed.tensor import DTensor from torch.optim import AdamW from transformer_engine.common.recipe import Format +from transformer_engine.pytorch.optimizers import FusedAdam from checkpoint import ( _ckpt_futures, @@ -50,6 +53,7 @@ from distributed_config import DistributedConfig from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_layer_precision from scheduler import get_cosine_annealing_schedule_with_warmup @@ -57,6 +61,31 @@ logger.setLevel(logging.INFO) +def _init_master_weights_from_high_precision( + optimizer: FusedAdam, model: torch.nn.Module, device: torch.device +) -> None: + """Initialize optimizer master weights from high-precision init values. + + When quantized_model_init is used with preserve_high_precision_init_val=True, each FP8 parameter + stores the original BF16 init values in CPU memory. This function initializes optimizer state + for all parameters, then overwrites master weights for quantized params with the preserved + high-precision values instead of dequantized FP8 values. + """ + count = 0 + for name, param in model.named_parameters(): + optimizer.initialize_state(param, store_param_remainders=False) + local = param._local_tensor if isinstance(param, DTensor) else param + if hasattr(local, "get_high_precision_init_val"): + hp_val = local.get_high_precision_init_val() + if hp_val is not None: + optimizer.set_scaled_state(param, "master_param", hp_val.to(device=device, dtype=torch.float32)) + local.clear_high_precision_init_val() + count += 1 + logger.debug("Seeded master weight for %s from high-precision init val", name) + if count > 0: + logger.info("Initialized %d master weight(s) from high-precision init values", count) + + @hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2") def main(args: DictConfig) -> float | None: """Train Llama3 with TE layers using FSDP2 with Context Parallelism. @@ -79,33 +108,85 @@ def main(args: DictConfig) -> float | None: logger.info("Created device mesh: %s", device_mesh) # --- Model Configuration --- - # Create quantization recipes -- only used if FP8/FP4 is enabled in the config. + config = NVLlamaConfig.from_pretrained( + args.config_name_or_path, + dtype=torch.bfloat16, + **args.config_kwargs, + ) + + # Resolve layer-wise quantization assignments and store on config. + layer_precision = resolve_layer_precision( + num_layers=config.num_hidden_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + config.layer_precision = layer_precision + + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + layer_precision=layer_precision, + ) + + # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config. fp8_recipe = None + fp4_recipe = None if args.fp8_config.enabled: fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) - - fp4_recipe = None if args.fp4_config.enabled: - fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs) + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) - # --- Model Initialization --- - config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + if args.fp8_config.quantized_model_init_kwargs.get("enabled", False) and not ( + args.fp8_config.enabled or args.fp4_config.enabled + ): + raise ValueError( + "fp8_config.quantized_model_init_kwargs.enabled=true requires fp8_config.enabled=true or " + "fp4_config.enabled=true. Enable at least one quantization format to use quantized model initialization." + ) - with torch.device("meta") if args.use_meta_device else nullcontext(): + # --- Model Initialization --- + # Optionally use transformer engine to initialize only fp8 versions of weights by setting + # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 + # and fp8 versions of weights are kept. + with ( + torch.device("meta") if args.use_meta_device else nullcontext(), + transformer_engine.pytorch.quantized_model_init( + recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs + ), + ): model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) logger.info("Initialized Model:\n%s", model) + def _log_memory(tag: str) -> None: + """Log GPU memory stats.""" + alloc = torch.cuda.memory_allocated() / (1024**3) + reserved = torch.cuda.memory_reserved() / (1024**3) + peak = torch.cuda.max_memory_allocated() / (1024**3) + logger.info("[Memory: %s] allocated=%.2f GB, reserved=%.2f GB, peak=%.2f GB", tag, alloc, reserved, peak) + + _log_memory("after_model_init") + # --- Distributed Wrapping (FSDP2 + CP) --- cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp") + mp_policy = MixedPrecisionPolicy() + # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. # Each decoder layer should be individually sharded before sharding the full model. for layer in model.model.layers: - fully_shard(layer, mesh=cp_dp_mesh) - fully_shard(model, mesh=cp_dp_mesh) + fully_shard(layer, mesh=cp_dp_mesh, mp_policy=mp_policy) + fully_shard(model, mesh=cp_dp_mesh, mp_policy=mp_policy) + + _log_memory("after_fsdp_wrap") # Attach the CP group to the model. for layer in model.model.layers: @@ -115,13 +196,28 @@ def main(args: DictConfig) -> float | None: torch.cuda.Stream(), ) + # Attach quantization recipes to the model (layer precision is already on config). + model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + if args.use_meta_device: # TE layers require special handling to initialize the weights from the meta device. model.init_empty_weights() + # Assign names to layers so debug API can identify them + if args.quant_stats_config.enabled: + debug_api.infer_and_assign_layer_names(model) + # --- Optimizer & Scheduler --- # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). - optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + adamw_kwargs = OmegaConf.to_container(args.adamw_kwargs, resolve=True) + if args.use_fp32_master_weights_fused: + # TE FusedAdam maintains FP32 master copies of BF16 params internally. + # 'fused' kwarg is not used by TE's FusedAdam (it's always fused). + adamw_kwargs.pop("fused", None) + optimizer = FusedAdam(model.parameters(), master_weights=True, **adamw_kwargs) # type: ignore + logger.info("Using TE FusedAdam with FP32 master weights") + else: + optimizer = AdamW(model.parameters(), **adamw_kwargs) # type: ignore scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) if args.use_torch_compile: @@ -177,6 +273,11 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 + if args.use_fp32_master_weights_fused and args.fp8_config.quantized_model_init_kwargs.get( + "preserve_high_precision_init_val", False + ): + _init_master_weights_from_high_precision(optimizer, model, device) + perf_logger = PerfLogger(dist_config, args, start_step=start_step) gc.collect() @@ -192,10 +293,9 @@ def main(args: DictConfig) -> float | None: micro_step += 1 - # Forward pass with mixed precision. + # Forward pass - quantization autocast is handled inside the model via set_recipes(). with nvtx.annotate("Forward pass", color="green"): - with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): - outputs = model(**batch) + outputs = model(**batch) # Backward pass - scale loss by grad_acc_steps for proper gradient averaging loss = outputs.loss / args.grad_acc_steps diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/collator.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/collator.py index 4555c1762a..058b625bec 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/collator.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/collator.py @@ -341,6 +341,19 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) + def state_dict(self) -> dict: + """Delegate to the underlying HF IterableDataset's state tracking. + + This enables StatefulDataLoader to save/restore the stream position instead of + falling back to naive fast-forward (which crashes on cross-process restarts due to + last_yielded_worker_id mismatch with non-deterministic streaming datasets). + """ + return self.dataset.state_dict() + + def load_state_dict(self, state_dict: dict) -> None: + """Restore the underlying HF IterableDataset's stream position.""" + self.dataset.load_state_dict(state_dict) + @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py index 62171cd237..994c4f876b 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py @@ -58,6 +58,7 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + layer_precision: list[str | None] | None = None def __init__( self, @@ -223,11 +224,54 @@ def _init_method(x): self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq + self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None + self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + def set_recipes( + self, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + ) -> None: + """Attach quantization recipe objects for per-layer autocast. + + Recipes are not serializable and must be set at runtime after model creation + and sharding (FSDP/DDP) but before training. The per-layer precision + assignments are read from ``self.config.layer_precision``. + + Args: + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. + """ + precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + def forward( self, input_ids: torch.Tensor | None = None, @@ -304,12 +348,14 @@ def forward( if te_rope_emb.dtype != torch.float32: warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning) - with self.get_autocast_context(None, outer=True): - for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled + # by get_layer_autocast(), which nests inside this context. + with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe): + for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): if output_hidden_states: all_hidden_states = (*all_hidden_states, hidden_states) - with self.get_autocast_context(layer_idx): + with self.get_layer_autocast(layer_number): hidden_states = decoder_layer( hidden_states, attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, @@ -369,8 +415,12 @@ def get_autocast_context( if init and self.config.use_quantized_model_init: if precision in ("fp8", "fp4"): - return transformer_engine.pytorch.quantized_model_init(recipe=recipe) - return nullcontext() + # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext() + # preserves the outer context's settings (recipe, preserve_high_precision_init_val). + # A nested quantized_model_init would override preserve_high_precision_init_val to False. + return nullcontext() + # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context. + return transformer_engine.pytorch.quantized_model_init(enabled=False) if precision == "fp8": if recipe is None: @@ -597,8 +647,3 @@ def reorder_cache(self, beam_idx: torch.LongTensor): updated_key_cache = key_cache.index_select(0, beam_idx) updated_value_cache = value_cache.index_select(0, beam_idx) self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache) - - @property - def is_compileable(self) -> bool: - """Return False as this cache is not compatible with torch.compile.""" - return False From 3b71e25474f295f34250c1ac1be7ad5c9084306b Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Tue, 12 May 2026 20:15:12 +0000 Subject: [PATCH 02/10] Rename use_fp32_master_weights_fused to use_fp32_master_weights and remove dead code Simplify the config parameter name since FusedAdam is now the only FP32 master weights strategy. Also remove a stale guard in train_ddp.py and a duplicate config line in L2_lingua_7b_pure_bf16.yaml. Co-Authored-By: Claude Opus 4.6 --- .../recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml | 2 +- .../recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml | 2 +- .../llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml | 1 - .../recipes/llama3_native_te/hydra_config/defaults.yaml | 2 +- bionemo-recipes/recipes/llama3_native_te/tests/test_train.py | 4 ++-- bionemo-recipes/recipes/llama3_native_te/train_ddp.py | 3 --- bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py | 4 ++-- bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py | 4 ++-- 8 files changed, 9 insertions(+), 13 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml index 15069fde62..84196961cc 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml @@ -21,7 +21,7 @@ use_sequence_packing: false use_meta_device: true # FP32 master weights via TE FusedAdam -use_fp32_master_weights_fused: true +use_fp32_master_weights: true wandb: name: lingua-70b-bf16-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml index 7ec8784a9c..d0a78abc16 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml @@ -12,7 +12,7 @@ config_kwargs: use_sequence_packing: true # FP32 master weights via TE FusedAdam (recommended over MixedPrecisionPolicy) -use_fp32_master_weights_fused: true +use_fp32_master_weights: true wandb: name: lingua-7b-bf16 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml index 6defc71b88..152574bc87 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml @@ -5,7 +5,6 @@ defaults: - _self_ # No FP32 master weights at all -use_fp32_master_weights_fused: null use_fp32_master_weights: null wandb: diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index 0c20326dbf..9680140ddf 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -86,7 +86,7 @@ quant_stats_config: # Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. fp8_layers: null fp4_layers: null -use_fp32_master_weights_fused: null # Use TE FusedAdam for FP32 master weights +use_fp32_master_weights: null # Use TE FusedAdam for FP32 master weights gradient_debug: enabled: false diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index 0ddcbf1a60..c73c739fd6 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -454,7 +454,7 @@ def test_sanity_convergence_fsdp2_te_fused_adam(tmp_path, recipe_path): f"+wandb.dir={tmp_path}", f"checkpoint.ckpt_dir={tmp_path}", "checkpoint.resume_from_checkpoint=false", - "use_fp32_master_weights_fused=true", + "use_fp32_master_weights=true", ], ) @@ -478,7 +478,7 @@ def test_sanity_convergence_fsdp2_te_fused_adam_fp8(tmp_path, recipe_path): f"+wandb.dir={tmp_path}", f"checkpoint.ckpt_dir={tmp_path}", "checkpoint.resume_from_checkpoint=false", - "use_fp32_master_weights_fused=true", + "use_fp32_master_weights=true", "fp8_config.enabled=true", "use_sequence_packing=true", "config_kwargs.attn_input_format=thd", diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 548931be67..7d589045f1 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -66,9 +66,6 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - if args.use_fp32_master_weights: - raise ValueError("FP32 master weights are not supported with DDP. Use train_fsdp2.py instead.") - # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2. device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",)) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 803c0e2969..5de7182fd1 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -276,7 +276,7 @@ def _log_memory(tag: str) -> None: # --- Optimizer & Scheduler --- # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). adamw_kwargs = OmegaConf.to_container(args.adamw_kwargs, resolve=True) - if args.use_fp32_master_weights_fused: + if args.use_fp32_master_weights: # TE FusedAdam maintains FP32 master copies of BF16 params internally. # 'fused' kwarg is not used by TE's FusedAdam (it's always fused). adamw_kwargs.pop("fused", None) @@ -322,7 +322,7 @@ def _log_memory(tag: str) -> None: # seed FP32 master weights from the original high-precision init values (not dequantized FP8). # Skip on resume — checkpoint already has correct master weights, and eager dequantize() can # invalidate QuantizedTensor storage causing FSDP2 forward failures. - if args.use_fp32_master_weights_fused and args.fp8_config.quantized_model_init_kwargs.get( + if args.use_fp32_master_weights and args.fp8_config.quantized_model_init_kwargs.get( "preserve_high_precision_init_val", False ): _init_master_weights_from_high_precision(optimizer, model, device) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index fefd14367b..d7bf4ac9f6 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -210,7 +210,7 @@ def _log_memory(tag: str) -> None: # --- Optimizer & Scheduler --- # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). adamw_kwargs = OmegaConf.to_container(args.adamw_kwargs, resolve=True) - if args.use_fp32_master_weights_fused: + if args.use_fp32_master_weights: # TE FusedAdam maintains FP32 master copies of BF16 params internally. # 'fused' kwarg is not used by TE's FusedAdam (it's always fused). adamw_kwargs.pop("fused", None) @@ -273,7 +273,7 @@ def _log_memory(tag: str) -> None: start_step = 0 epoch = 0 - if args.use_fp32_master_weights_fused and args.fp8_config.quantized_model_init_kwargs.get( + if args.use_fp32_master_weights and args.fp8_config.quantized_model_init_kwargs.get( "preserve_high_precision_init_val", False ): _init_master_weights_from_high_precision(optimizer, model, device) From f68ff4ed4c73a1f51eced2c7c3905513d22b7d86 Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Wed, 13 May 2026 22:04:42 +0000 Subject: [PATCH 03/10] Revert TokenPackingDataset state_dict/load_state_dict Adding state_dict/load_state_dict causes StatefulDataLoader to switch from fast-forward replay to stateful restore, which produces incorrect batches because the packing generator state is not serializable. Fast-forward replay works correctly with the deterministic shuffle seed. Signed-off-by: Savitha Srinivasan --- bionemo-recipes/models/esm2/collator.py | 13 ------------- bionemo-recipes/models/llama3/collator.py | 13 ------------- bionemo-recipes/models/mixtral/collator.py | 13 ------------- bionemo-recipes/models/qwen/collator.py | 13 ------------- bionemo-recipes/recipes/esm2_native_te/collator.py | 13 ------------- bionemo-recipes/recipes/esm2_peft_te/collator.py | 13 ------------- .../recipes/llama3_native_te/collator.py | 13 ------------- .../recipes/opengenome2_llama_native_te/collator.py | 13 ------------- 8 files changed, 104 deletions(-) diff --git a/bionemo-recipes/models/esm2/collator.py b/bionemo-recipes/models/esm2/collator.py index add487c997..e83d719eb7 100644 --- a/bionemo-recipes/models/esm2/collator.py +++ b/bionemo-recipes/models/esm2/collator.py @@ -335,19 +335,6 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) - def state_dict(self) -> dict: - """Delegate to the underlying HF IterableDataset's state tracking. - - This enables StatefulDataLoader to save/restore the stream position instead of - falling back to naive fast-forward (which crashes on cross-process restarts due to - last_yielded_worker_id mismatch with non-deterministic streaming datasets). - """ - return self.dataset.state_dict() - - def load_state_dict(self, state_dict: dict) -> None: - """Restore the underlying HF IterableDataset's stream position.""" - self.dataset.load_state_dict(state_dict) - @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py index 058b625bec..4555c1762a 100644 --- a/bionemo-recipes/models/llama3/collator.py +++ b/bionemo-recipes/models/llama3/collator.py @@ -341,19 +341,6 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) - def state_dict(self) -> dict: - """Delegate to the underlying HF IterableDataset's state tracking. - - This enables StatefulDataLoader to save/restore the stream position instead of - falling back to naive fast-forward (which crashes on cross-process restarts due to - last_yielded_worker_id mismatch with non-deterministic streaming datasets). - """ - return self.dataset.state_dict() - - def load_state_dict(self, state_dict: dict) -> None: - """Restore the underlying HF IterableDataset's stream position.""" - self.dataset.load_state_dict(state_dict) - @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/models/mixtral/collator.py b/bionemo-recipes/models/mixtral/collator.py index 058b625bec..4555c1762a 100644 --- a/bionemo-recipes/models/mixtral/collator.py +++ b/bionemo-recipes/models/mixtral/collator.py @@ -341,19 +341,6 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) - def state_dict(self) -> dict: - """Delegate to the underlying HF IterableDataset's state tracking. - - This enables StatefulDataLoader to save/restore the stream position instead of - falling back to naive fast-forward (which crashes on cross-process restarts due to - last_yielded_worker_id mismatch with non-deterministic streaming datasets). - """ - return self.dataset.state_dict() - - def load_state_dict(self, state_dict: dict) -> None: - """Restore the underlying HF IterableDataset's stream position.""" - self.dataset.load_state_dict(state_dict) - @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/models/qwen/collator.py b/bionemo-recipes/models/qwen/collator.py index 058b625bec..4555c1762a 100644 --- a/bionemo-recipes/models/qwen/collator.py +++ b/bionemo-recipes/models/qwen/collator.py @@ -341,19 +341,6 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) - def state_dict(self) -> dict: - """Delegate to the underlying HF IterableDataset's state tracking. - - This enables StatefulDataLoader to save/restore the stream position instead of - falling back to naive fast-forward (which crashes on cross-process restarts due to - last_yielded_worker_id mismatch with non-deterministic streaming datasets). - """ - return self.dataset.state_dict() - - def load_state_dict(self, state_dict: dict) -> None: - """Restore the underlying HF IterableDataset's stream position.""" - self.dataset.load_state_dict(state_dict) - @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py index 058b625bec..4555c1762a 100644 --- a/bionemo-recipes/recipes/esm2_native_te/collator.py +++ b/bionemo-recipes/recipes/esm2_native_te/collator.py @@ -341,19 +341,6 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) - def state_dict(self) -> dict: - """Delegate to the underlying HF IterableDataset's state tracking. - - This enables StatefulDataLoader to save/restore the stream position instead of - falling back to naive fast-forward (which crashes on cross-process restarts due to - last_yielded_worker_id mismatch with non-deterministic streaming datasets). - """ - return self.dataset.state_dict() - - def load_state_dict(self, state_dict: dict) -> None: - """Restore the underlying HF IterableDataset's stream position.""" - self.dataset.load_state_dict(state_dict) - @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/esm2_peft_te/collator.py b/bionemo-recipes/recipes/esm2_peft_te/collator.py index 058b625bec..4555c1762a 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/collator.py +++ b/bionemo-recipes/recipes/esm2_peft_te/collator.py @@ -341,19 +341,6 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) - def state_dict(self) -> dict: - """Delegate to the underlying HF IterableDataset's state tracking. - - This enables StatefulDataLoader to save/restore the stream position instead of - falling back to naive fast-forward (which crashes on cross-process restarts due to - last_yielded_worker_id mismatch with non-deterministic streaming datasets). - """ - return self.dataset.state_dict() - - def load_state_dict(self, state_dict: dict) -> None: - """Restore the underlying HF IterableDataset's stream position.""" - self.dataset.load_state_dict(state_dict) - @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py index 058b625bec..4555c1762a 100644 --- a/bionemo-recipes/recipes/llama3_native_te/collator.py +++ b/bionemo-recipes/recipes/llama3_native_te/collator.py @@ -341,19 +341,6 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) - def state_dict(self) -> dict: - """Delegate to the underlying HF IterableDataset's state tracking. - - This enables StatefulDataLoader to save/restore the stream position instead of - falling back to naive fast-forward (which crashes on cross-process restarts due to - last_yielded_worker_id mismatch with non-deterministic streaming datasets). - """ - return self.dataset.state_dict() - - def load_state_dict(self, state_dict: dict) -> None: - """Restore the underlying HF IterableDataset's stream position.""" - self.dataset.load_state_dict(state_dict) - @dataclass class DataCollatorForContextParallel: diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/collator.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/collator.py index 058b625bec..4555c1762a 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/collator.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/collator.py @@ -341,19 +341,6 @@ def set_epoch(self, epoch: int): """Set the epoch for the dataset.""" self.dataset.set_epoch(epoch) - def state_dict(self) -> dict: - """Delegate to the underlying HF IterableDataset's state tracking. - - This enables StatefulDataLoader to save/restore the stream position instead of - falling back to naive fast-forward (which crashes on cross-process restarts due to - last_yielded_worker_id mismatch with non-deterministic streaming datasets). - """ - return self.dataset.state_dict() - - def load_state_dict(self, state_dict: dict) -> None: - """Restore the underlying HF IterableDataset's stream position.""" - self.dataset.load_state_dict(state_dict) - @dataclass class DataCollatorForContextParallel: From a611fb48386d22c37da0741b8847816a70173a31 Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Wed, 13 May 2026 22:52:46 +0000 Subject: [PATCH 04/10] Clean up hydra configs: rename 7b to 8b, remove experiment configs, restore pytest MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename L2_lingua_7b → L2_lingua_8b (model is Llama-3.1-8B) - Rename L2_lingua_7b_mxfp8_qinit → L2_lingua_8b_mxfp8_qinit - Keep only 4 example configs: 8b base, 8b mxfp8+qinit, 70b base, 70b mxfp8+qinit - Remove 11 experiment-specific configs (bf16_baseline, fp8, mxfp8, thd, cp4, etc.) - Restore pytest to requirements.txt (needed by CI runner) Signed-off-by: Savitha Srinivasan --- .../hydra_config/L2_lingua_70b_mxfp8.yaml | 23 --------------- .../hydra_config/L2_lingua_70b_mxfp8_cp4.yaml | 20 ------------- .../L2_lingua_70b_mxfp8_qinit_thd.yaml | 29 ------------------- .../hydra_config/L2_lingua_70b_mxfp8_thd.yaml | 18 ------------ .../L2_lingua_70b_mxfp8_thd_cp4.yaml | 20 ------------- .../hydra_config/L2_lingua_70b_thd.yaml | 16 ---------- .../L2_lingua_7b_bf16_baseline.yaml | 12 -------- .../hydra_config/L2_lingua_7b_fp8.yaml | 21 -------------- .../hydra_config/L2_lingua_7b_mxfp8.yaml | 22 -------------- .../L2_lingua_7b_mxfp8_fl1_qinit.yaml | 28 ------------------ .../hydra_config/L2_lingua_7b_pure_bf16.yaml | 12 -------- .../{L2_lingua_7b.yaml => L2_lingua_8b.yaml} | 6 ++-- ...nit.yaml => L2_lingua_8b_mxfp8_qinit.yaml} | 8 ++--- .../recipes/llama3_native_te/requirements.txt | 1 + 14 files changed, 8 insertions(+), 228 deletions(-) delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_cp4.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit_thd.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd_cp4.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_thd.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_bf16_baseline.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_fp8.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_fl1_qinit.yaml delete mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml rename bionemo-recipes/recipes/llama3_native_te/hydra_config/{L2_lingua_7b.yaml => L2_lingua_8b.yaml} (95%) rename bionemo-recipes/recipes/llama3_native_te/hydra_config/{L2_lingua_7b_mxfp8_qinit.yaml => L2_lingua_8b_mxfp8_qinit.yaml} (78%) diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8.yaml deleted file mode 100644 index d28022dad0..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# Lingua 70B MXFP8 with Context Parallelism (CP=2). -# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16. -# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. - -defaults: - - L2_lingua_70b - - _self_ - -# FP8 with MXFP8BlockScaling (hardware accelerated on Blackwell) -fp8_config: - enabled: true - fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling - fp8_format: E4M3 - fp8_recipe_kwargs: {} - quantized_model_init_kwargs: - enabled: false - -# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16 -fp8_layers: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79] - -wandb: - name: lingua-70b-mxfp8-fl1-cp2 - id: lingua-70b-mxfp8-fl1-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_cp4.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_cp4.yaml deleted file mode 100644 index c34d88124b..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_cp4.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# Lingua 70B MXFP8 with Context Parallelism (CP=4). -# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16. -# CP=4 reduces per-GPU activation memory, needed for B200 GPUs (192GB) -# where CP=2 OOMs with FP32 master weights. -# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. - -defaults: - - L2_lingua_70b_mxfp8 - - _self_ - -cp_size: 4 - -dataset: - # MXFP8 block size is 32. With CP=4, post-split dims must be divisible by 32. - # Pre-split: pad to 128 (= 32 * CP). After CP splits by 4, each chunk is divisible by 32. - pad_sequences_to_be_divisible_by: 128 - -wandb: - name: lingua-70b-mxfp8-fl1-cp4 - id: lingua-70b-mxfp8-fl1-cp4 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit_thd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit_thd.yaml deleted file mode 100644 index c08093bfca..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit_thd.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# Lingua 70B MXFP8 THD with quantized model init (all layers FP8) + CP=2. -# Quantized model init stores only FP8 weights (no BF16 copies), saving memory. -# FP32 master weights are maintained in FusedAdam optimizer. -# THD enables sequence packing for better GPU utilization. -# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. - -defaults: - - L2_lingua_70b - - _self_ - -config_kwargs: - attn_input_format: thd - self_attn_mask_type: padding - -use_sequence_packing: true - -# All layers in FP8 (no FL1 exclusion) — compatible with quantized_model_init. -fp8_config: - enabled: true - fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling - fp8_format: E4M3 - fp8_recipe_kwargs: {} - quantized_model_init_kwargs: - enabled: true - preserve_high_precision_init_val: true - -wandb: - name: lingua-70b-mxfp8-qinit-thd-cp2 - id: lingua-70b-mxfp8-qinit-thd-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd.yaml deleted file mode 100644 index fd30296024..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# Lingua 70B MXFP8 THD format with Context Parallelism (CP=2). -# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16. -# THD enables sequence packing for better GPU utilization. -# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. - -defaults: - - L2_lingua_70b_mxfp8 - - _self_ - -config_kwargs: - attn_input_format: thd - self_attn_mask_type: padding - -use_sequence_packing: true - -wandb: - name: lingua-70b-mxfp8-fl1-thd-cp2 - id: lingua-70b-mxfp8-fl1-thd-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd_cp4.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd_cp4.yaml deleted file mode 100644 index 8547047dd3..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_thd_cp4.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# Lingua 70B MXFP8 THD format with Context Parallelism (CP=4). -# FL1: layers 2-79 in FP8, layers 1 and 80 in BF16. -# THD enables sequence packing for better GPU utilization. -# CP=4 reduces per-GPU activation memory, needed for B200 GPUs (192GB) -# where CP=2 OOMs with FP32 master weights. -# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. - -defaults: - - L2_lingua_70b_mxfp8_cp4 - - _self_ - -config_kwargs: - attn_input_format: thd - self_attn_mask_type: padding - -use_sequence_packing: true - -wandb: - name: lingua-70b-mxfp8-fl1-thd-cp4 - id: lingua-70b-mxfp8-fl1-thd-cp4 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_thd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_thd.yaml deleted file mode 100644 index 9143dc205d..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_thd.yaml +++ /dev/null @@ -1,16 +0,0 @@ -# Lingua 70B BF16 THD format with Context Parallelism (CP=2). -# THD enables sequence packing for better GPU utilization. - -defaults: - - L2_lingua_70b - - _self_ - -config_kwargs: - attn_input_format: thd - self_attn_mask_type: padding - -use_sequence_packing: true - -wandb: - name: lingua-70b-bf16-thd-cp2 - id: lingua-70b-bf16-thd-cp2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_bf16_baseline.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_bf16_baseline.yaml deleted file mode 100644 index 51ca061a3b..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_bf16_baseline.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# BF16 baseline for Lingua 7B — benchmarking step time / throughput against MXFP8. -# Same model, dataset, and hyperparams as L2_lingua_7b_mxfp8_qinit but without FP8. - -defaults: - - L2_lingua_7b - - _self_ - -num_train_steps: 1_000 - -wandb: - name: lingua-7b-bf16-baseline - id: lingua-7b-bf16-baseline diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_fp8.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_fp8.yaml deleted file mode 100644 index a787e0bff0..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_fp8.yaml +++ /dev/null @@ -1,21 +0,0 @@ -# Lingua 7B FP8 Block Scaling - FL1 (layer 1 and 32 in BF16, layers 2-31 in FP8) - -defaults: - - L2_lingua_7b - - _self_ - -# FP8 with Float8BlockScaling and E4M3 format -fp8_config: - enabled: true - fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling - fp8_format: E4M3 - fp8_recipe_kwargs: {} - quantized_model_init_kwargs: - enabled: false - -# FL1: layers 2-31 in FP8, layers 1 and 32 in BF16 -fp8_layers: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] - -wandb: - name: lingua-7b-fp8-bs-fl1 - id: lingua-7b-fp8-bs-fl1-v2 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8.yaml deleted file mode 100644 index 5f4450d65b..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# Lingua 7B MXFP8 - FL1 (layer 1 and 32 in BF16, layers 2-31 in FP8) -# Requires Blackwell GPUs (GB200) for hardware MXFP8 support - -defaults: - - L2_lingua_7b - - _self_ - -# FP8 with MXFP8BlockScaling (hardware accelerated on Blackwell) -fp8_config: - enabled: true - fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling - fp8_format: E4M3 - fp8_recipe_kwargs: {} - quantized_model_init_kwargs: - enabled: false - -# FL1: layers 2-31 in FP8, layers 1 and 32 in BF16 -fp8_layers: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] - -wandb: - name: lingua-7b-mxfp8-fl1 - id: lingua-7b-mxfp8-fl1 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_fl1_qinit.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_fl1_qinit.yaml deleted file mode 100644 index db1327411c..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_fl1_qinit.yaml +++ /dev/null @@ -1,28 +0,0 @@ -# Lingua 7B MXFP8 with quantized model init + FL1 (first/last layer BF16). -# Layers 2-31 in FP8, layers 1 and 32 in BF16. -# Quantized model init stores only FP8 weights (no BF16 copies), saving memory. -# FP32 master weights are maintained in FusedAdam optimizer. -# Requires config_kwargs.use_quantized_model_init=true so per-layer init -# correctly disables quantization for BF16 layers. -# Requires Blackwell GPUs (GB200) for hardware MXFP8 support. - -defaults: - - L2_lingua_7b_mxfp8 - - _self_ - -config_kwargs: - attn_input_format: thd - use_quantized_model_init: true - -fp8_config: - enabled: true - fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling - fp8_format: E4M3 - fp8_recipe_kwargs: {} - quantized_model_init_kwargs: - enabled: true - preserve_high_precision_init_val: true - -wandb: - name: lingua-7b-mxfp8-fl1-qinit - id: lingua-7b-mxfp8-fl1-qinit diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml deleted file mode 100644 index 152574bc87..0000000000 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_pure_bf16.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# Lingua 7B pure BF16 - no FP32 master weights - -defaults: - - L2_lingua_7b - - _self_ - -# No FP32 master weights at all -use_fp32_master_weights: null - -wandb: - name: lingua-7b-pure-bf16 - id: lingua-7b-pure-bf16 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b.yaml similarity index 95% rename from bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml rename to bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b.yaml index d0a78abc16..65098a53fd 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b.yaml @@ -15,9 +15,9 @@ use_sequence_packing: true use_fp32_master_weights: true wandb: - name: lingua-7b-bf16 - project: lingua-7b - id: lingua-7b-bf16-v2 + name: lingua-8b-bf16 + project: lingua-8b + id: lingua-8b-bf16 num_train_steps: 60_000 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_qinit.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b_mxfp8_qinit.yaml similarity index 78% rename from bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_qinit.yaml rename to bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b_mxfp8_qinit.yaml index c6261ce9c2..b9227cf101 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_7b_mxfp8_qinit.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b_mxfp8_qinit.yaml @@ -1,10 +1,10 @@ -# Lingua 7B MXFP8 with quantized model init (all layers FP8). +# Lingua 8B MXFP8 with quantized model init (all layers FP8). # Quantized model init stores only FP8 weights (no BF16 copies), saving memory. # FP32 master weights are maintained in FusedAdam optimizer. # Requires Blackwell GPUs (GB200) for hardware MXFP8 support. defaults: - - L2_lingua_7b + - L2_lingua_8b - _self_ # All layers in FP8 (no FL1 exclusion) — compatible with outer quantized_model_init. @@ -18,5 +18,5 @@ fp8_config: preserve_high_precision_init_val: true wandb: - name: lingua-7b-mxfp8-allfp8-qinit - id: lingua-7b-mxfp8-allfp8-qinit + name: lingua-8b-mxfp8-qinit + id: lingua-8b-mxfp8-qinit diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt index 8a15cec936..a36f3df85f 100644 --- a/bionemo-recipes/recipes/llama3_native_te/requirements.txt +++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt @@ -9,3 +9,4 @@ transformers wandb zstandard nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect +pytest From 0bfaceb2a29b7f8285fb687f0e35cc3867a96e46 Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Thu, 14 May 2026 04:52:26 +0000 Subject: [PATCH 05/10] Fix CI failures: restore is_compileable and unify quantized_model_init - Restore is_compileable property on HFInferenceParams (accidentally dropped from PR 1500), required by newer transformers generate(). - Unify get_autocast_context init path to work both standalone (model tests, no outer context) and with outer quantized_model_init (recipe training). FP8/FP4 layers use per-layer quantized_model_init with preserve_high_precision_init_val=True; BF16 layers use quantized_model_init(enabled=False) to override any outer context. Signed-off-by: Savitha Srinivasan --- bionemo-recipes/models/llama3/modeling_llama_te.py | 13 ++++++++----- .../recipes/llama3_native_te/modeling_llama_te.py | 13 ++++++++----- .../modeling_llama_te.py | 13 ++++++++----- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py index 1d912f1ba1..d61976c424 100644 --- a/bionemo-recipes/models/llama3/modeling_llama_te.py +++ b/bionemo-recipes/models/llama3/modeling_llama_te.py @@ -409,11 +409,9 @@ def get_autocast_context( if init and self.config.use_quantized_model_init: if precision in ("fp8", "fp4"): - # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext() - # preserves the outer context's settings (recipe, preserve_high_precision_init_val). - # A nested quantized_model_init would override preserve_high_precision_init_val to False. - return nullcontext() - # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context. + return transformer_engine.pytorch.quantized_model_init( + recipe=recipe, preserve_high_precision_init_val=True + ) return transformer_engine.pytorch.quantized_model_init(enabled=False) if precision == "fp8": @@ -633,6 +631,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int: return 0 return max(self.sequences.values()) + @property + def is_compileable(self) -> bool: + """Required by HuggingFace transformers generate() auto-compile check.""" + return False + def reorder_cache(self, beam_idx: torch.LongTensor): """Reorder the cache based on the beam indices.""" if isinstance(self.cache_manager, PagedKVCacheManager): diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py index 994c4f876b..0bdb8a23e8 100644 --- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py @@ -415,11 +415,9 @@ def get_autocast_context( if init and self.config.use_quantized_model_init: if precision in ("fp8", "fp4"): - # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext() - # preserves the outer context's settings (recipe, preserve_high_precision_init_val). - # A nested quantized_model_init would override preserve_high_precision_init_val to False. - return nullcontext() - # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context. + return transformer_engine.pytorch.quantized_model_init( + recipe=recipe, preserve_high_precision_init_val=True + ) return transformer_engine.pytorch.quantized_model_init(enabled=False) if precision == "fp8": @@ -639,6 +637,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int: return 0 return max(self.sequences.values()) + @property + def is_compileable(self) -> bool: + """Required by HuggingFace transformers generate() auto-compile check.""" + return False + def reorder_cache(self, beam_idx: torch.LongTensor): """Reorder the cache based on the beam indices.""" if isinstance(self.cache_manager, PagedKVCacheManager): diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py index 994c4f876b..0bdb8a23e8 100644 --- a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py @@ -415,11 +415,9 @@ def get_autocast_context( if init and self.config.use_quantized_model_init: if precision in ("fp8", "fp4"): - # Let the outer quantized_model_init context handle FP8/FP4 layers. Using nullcontext() - # preserves the outer context's settings (recipe, preserve_high_precision_init_val). - # A nested quantized_model_init would override preserve_high_precision_init_val to False. - return nullcontext() - # BF16 layers: explicitly disable quantized init to override any outer quantized_model_init context. + return transformer_engine.pytorch.quantized_model_init( + recipe=recipe, preserve_high_precision_init_val=True + ) return transformer_engine.pytorch.quantized_model_init(enabled=False) if precision == "fp8": @@ -639,6 +637,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int: return 0 return max(self.sequences.values()) + @property + def is_compileable(self) -> bool: + """Required by HuggingFace transformers generate() auto-compile check.""" + return False + def reorder_cache(self, beam_idx: torch.LongTensor): """Reorder the cache based on the beam indices.""" if isinstance(self.cache_manager, PagedKVCacheManager): From 013479301f03fd541b394b34b7b779fa3abb56c0 Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Thu, 14 May 2026 05:30:58 +0000 Subject: [PATCH 06/10] Fix test_quantized_model_init: parametrize across all FP8 recipes with xfail Match the pattern used by model-level tests and conftest.py: parametrize across DelayedScaling, Float8CurrentScaling, Float8BlockScaling, and MXFP8BlockScaling with automatic xfail for unsupported hardware. Previously hardcoded Float8BlockScaling which requires sm90+ (Hopper) but CI runs on L4 (sm89). Signed-off-by: Savitha Srinivasan --- .../tests/test_quantized_model_init.py | 62 ++++++++++++------- 1 file changed, 40 insertions(+), 22 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py index 831bfe4ee8..440e588bb4 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py @@ -20,8 +20,8 @@ 2. FL1 + qinit: FP8 layers have QuantizedTensor weights, BF16 layers have regular BF16 weights 3. BF16 layers don't lose precision from an outer quantized_model_init context -Uses Float8BlockScaling instead of MXFP8BlockScaling so tests run on non-Blackwell GPUs. -The quantized_model_init behavior is recipe-agnostic. +Parametrized across all FP8 recipes with automatic xfail for unsupported hardware +(same pattern as conftest.py and the model-level tests). """ import sys @@ -30,7 +30,8 @@ import pytest import torch import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Float8BlockScaling +from transformer_engine.common import recipe as recipe_module +from transformer_engine.pytorch import fp8 as te_fp8 from transformer_engine.pytorch.tensor import QuantizedTensor @@ -52,6 +53,31 @@ requires_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +# FP8 recipes with hardware support checks — unsupported recipes auto-xfail. +_FP8_RECIPES = [ + ("DelayedScaling", recipe_module.DelayedScaling(), te_fp8.check_fp8_support), + ("Float8CurrentScaling", recipe_module.Float8CurrentScaling(), te_fp8.check_fp8_support), + ("Float8BlockScaling", recipe_module.Float8BlockScaling(), te_fp8.check_fp8_block_scaling_support), + ("MXFP8BlockScaling", recipe_module.MXFP8BlockScaling(), te_fp8.check_mxfp8_support), +] + + +def _parametrize_fp8_recipes(): + params = [] + for name, recipe, check_fn in _FP8_RECIPES: + supported, reason = check_fn() + params.append(pytest.param(recipe, id=name, marks=pytest.mark.xfail(condition=not supported, reason=reason))) + return params + + +fp8_recipe_fixture = pytest.fixture(params=_parametrize_fp8_recipes()) + + +@fp8_recipe_fixture +def qinit_recipe(request): + """FP8 recipe for quantized_model_init tests, with xfail for unsupported hardware.""" + return request.param + def _has_quantized_weights(layer) -> bool: """Check if a TE TransformerLayer has any QuantizedTensor parameters.""" @@ -70,29 +96,25 @@ def _has_high_precision_init_val(layer) -> bool: @requires_gpu -def test_all_fp8_qinit(): +def test_all_fp8_qinit(qinit_recipe): """All layers FP8 with quantized_model_init: all weights should be QuantizedTensors.""" - recipe = Float8BlockScaling() config = NVLlamaConfig( **_SMALL_CONFIG_KWARGS, attn_input_format="bshd", dtype=torch.bfloat16, ) - with te.quantized_model_init(recipe=recipe, enabled=True, preserve_high_precision_init_val=True): - model = NVLlamaForCausalLM(config, fp8_recipe=recipe) + with te.quantized_model_init(recipe=qinit_recipe, enabled=True, preserve_high_precision_init_val=True): + model = NVLlamaForCausalLM(config, fp8_recipe=qinit_recipe) - # All decoder layers should have quantized weights for i, layer in enumerate(model.model.layers): assert _has_quantized_weights(layer), f"Layer {i} should have QuantizedTensor weights" assert _has_high_precision_init_val(layer), f"Layer {i} should have high-precision init vals" @requires_gpu -def test_fl1_qinit_bf16_layers_not_quantized(): +def test_fl1_qinit_bf16_layers_not_quantized(qinit_recipe): """FL1 + qinit: BF16 layers (first/last) should NOT have quantized weights.""" - recipe = Float8BlockScaling() - # FL1: layers 2,3 in FP8 (1-indexed), layers 1,4 in BF16 layer_precision = [None, "fp8", "fp8", None] config = NVLlamaConfig( **_SMALL_CONFIG_KWARGS, @@ -102,8 +124,8 @@ def test_fl1_qinit_bf16_layers_not_quantized(): use_quantized_model_init=True, ) - with te.quantized_model_init(recipe=recipe, enabled=True, preserve_high_precision_init_val=True): - model = NVLlamaForCausalLM(config, fp8_recipe=recipe) + with te.quantized_model_init(recipe=qinit_recipe, enabled=True, preserve_high_precision_init_val=True): + model = NVLlamaForCausalLM(config, fp8_recipe=qinit_recipe) # BF16 layers (0 and 3, 0-indexed) should NOT have quantized weights assert not _has_quantized_weights(model.model.layers[0]), "First layer (BF16) should not have QuantizedTensors" @@ -115,9 +137,8 @@ def test_fl1_qinit_bf16_layers_not_quantized(): @requires_gpu -def test_fl1_qinit_fp8_layers_preserve_high_precision(): +def test_fl1_qinit_fp8_layers_preserve_high_precision(qinit_recipe): """FL1 + qinit: FP8 layers should preserve high-precision init vals for master weights.""" - recipe = Float8BlockScaling() layer_precision = [None, "fp8", "fp8", None] config = NVLlamaConfig( **_SMALL_CONFIG_KWARGS, @@ -127,8 +148,8 @@ def test_fl1_qinit_fp8_layers_preserve_high_precision(): use_quantized_model_init=True, ) - with te.quantized_model_init(recipe=recipe, enabled=True, preserve_high_precision_init_val=True): - model = NVLlamaForCausalLM(config, fp8_recipe=recipe) + with te.quantized_model_init(recipe=qinit_recipe, enabled=True, preserve_high_precision_init_val=True): + model = NVLlamaForCausalLM(config, fp8_recipe=qinit_recipe) # FP8 layers should have high-precision init values assert _has_high_precision_init_val(model.model.layers[1]), "FP8 layer should have high-precision init vals" @@ -144,9 +165,8 @@ def test_fl1_qinit_fp8_layers_preserve_high_precision(): @requires_gpu -def test_fl1_no_qinit_baseline(): +def test_fl1_no_qinit_baseline(qinit_recipe): """FL1 without qinit: all weights should be regular BF16 tensors (baseline).""" - recipe = Float8BlockScaling() layer_precision = [None, "fp8", "fp8", None] config = NVLlamaConfig( **_SMALL_CONFIG_KWARGS, @@ -155,9 +175,7 @@ def test_fl1_no_qinit_baseline(): layer_precision=layer_precision, ) - # No quantized_model_init context — default behavior - model = NVLlamaForCausalLM(config, fp8_recipe=recipe) + model = NVLlamaForCausalLM(config, fp8_recipe=qinit_recipe) - # No layers should have quantized weights for i, layer in enumerate(model.model.layers): assert not _has_quantized_weights(layer), f"Layer {i} should not have QuantizedTensors without qinit" From cfd448369e1a38a2422b24a3dd0d4de8ca8f81cc Mon Sep 17 00:00:00 2001 From: Savitha Srinivasan Date: Thu, 14 May 2026 18:37:16 +0000 Subject: [PATCH 07/10] Add MXFP8 benchmark images for README and PR description Signed-off-by: Savitha Srinivasan --- .../images/llama3/lingua-70b-mxfp8-1node.png | Bin 0 -> 269042 bytes .../llama3/lingua-70b-mxfp8-multinode.png | Bin 0 -> 232840 bytes .../llama3/lingua-7b-mxfp8-multinode.png | Bin 0 -> 227957 bytes .../images/llama3/lingua-8b-mxfp8-1node.png | Bin 0 -> 228791 bytes .../llama3/lingua-8b-vs-70b-mxfp8-uplift.png | Bin 0 -> 140587 bytes 5 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/docs/assets/images/llama3/lingua-70b-mxfp8-1node.png create mode 100644 docs/docs/assets/images/llama3/lingua-70b-mxfp8-multinode.png create mode 100644 docs/docs/assets/images/llama3/lingua-7b-mxfp8-multinode.png create mode 100644 docs/docs/assets/images/llama3/lingua-8b-mxfp8-1node.png create mode 100644 docs/docs/assets/images/llama3/lingua-8b-vs-70b-mxfp8-uplift.png diff --git a/docs/docs/assets/images/llama3/lingua-70b-mxfp8-1node.png b/docs/docs/assets/images/llama3/lingua-70b-mxfp8-1node.png new file mode 100644 index 0000000000000000000000000000000000000000..ef5721e5780665d1c92b67b9dea041a1b3c1989a GIT binary patch literal 269042 zcmeFZ_al~l_&0u)N=9iYLWu_1qlA)BDI`M3&dy#TtE|dQvdRb(cTqA>l1QXo*REc=Ng|Omkx1JG zcJ9DePNzOx!2d+;FRR%rSsB?o>e(8SuIt%bn_JnNo8CEk&(PM+)XGwjho6r}nCqyC zy}h-aC@-(Y|N9d>R<_2xWf~5RcohoktLk ze$LNtD~DUJxT~XwL8wnpQd-k!v>5rbT7%Y;f>8!nzR>F*_1+`7ljVQ?<3Q>Ghpqqn zm+%wfymzgS|F1tr{8nlJ0~5vn_RIJxN8%1?;+y~L>preaEerWyzd(FHLE!(q!fe}3jX^70~+3FFwFrAB@D@L{#vw-fX# zd?STzSQ8QwW(qR`0;s)xeed*_`;7faXPlgzbgI5kLRllZ{mx}~5tXqLw^d)m`tW$u zw&d1za&oH5+k-0$fBY*R(Tr*&w&{4R z`wR^kkRDt;UjF=)l9QeRCkJZ?yUvg6wrmIe`Q6u?#y8Jec3)IbPr- z(bpF~3x9ql+K;yu+C5+h@(;#)>verJZPn4{T=FKB#!EdC9#ye1Q*IhBe)HxHg_M+( za)#!i7*U6E-(8GTe}10Lv!Uo2wOw`FSc{@-8m^DXx>HHHw6yffcDPPKNvU3`;^N)A z96}ad-a&`YmH+s0@c#Y#&F$@vVqy+=cXv;<>v;4JRQk*O&HLqFsDIAk_cx}9mjZ?| z8HQyI4RYV!ss^hX6}YW_?shV6jIQ;!s|q+2%~u!ef5z;GvQxnc@oOTVB)oPIt9#^v zYvka@Gl)%Rq?Xu{2j+xxOD=95a{^Q3FlOOL6jeLFi2+Pz>t~t@ z=eG>4mmka}B_&y2Kie}StLwdo^-*RfFInkw<(Dt)g@pwL=Zw{Lbbf7>oybhZ5)`iU z&h3A4$-Fz)#`(6YoSa<7L~Z6^N=?4E``^ERJ*Txw++)Icv_z{7ugl4Gv?NMzkD#HT zP!#31w6x@6VPUzk%EHR(iCcVaaI!hxOINsa;%cWxPGI0JU9-L2Fb(~#abu`k` zW9jiol8rFZDRJj+cHpp;-)O8>WT)xOcQ6Y{l8!Qy8|&b((>bA8GPCiV+pD&K{NMq@ ze=qT=CmUnVdH&C4q3H2a;g;ShVmCs~sgx|2VUwSmo0?1RK{FF%QTyzKx{QoppTMsY z4`U-EBg(@S)xpeyVq#+Qk$1nRQB)m1a-{L}nKRDSyvL4t?U(PgYiVgw?Ee$KsHCcD z9Ljd%U^nwy)l8#1S!(f4Yinz2_wVQ5As?l?8M(Oi^32GC2PAoVM#h(0uhnOsOc5Q3 z^Vm4$nCG^-P^%^X*;UxMkx4sJ{euxj)w#Rh*tI*tI8|-S&3p1y#?R}IwA~1@4iY`HAId{$!@w=$4r)YTZXFEG(Rw8Lm%|ujA#saN$V7y+8Y+ z>1b)aX_zmt)86o*Jn-J2c3S zhCPvy^lBO!bT#=0xo>|_Z;BOjp)?Tl=H%ns^W@1B&)Mh8p7!P56eb<-S&X~p#@lRW zN3J)1eEM{cb|LrekEJ-lwA!i?_t#g>-2QC-1SQcnIhf%LC7XP-0(%gu-`rS>7u_QP z%C{F*KXiRCreP~l`)H#0@LhFR*NX`9{QUe8k>1Z1KNFPS+-%)!PuF-9%q%`4qAgjI z8BFx@*O445X|#n!MOlXR2cyMY4&s(nqoilKELsYSh)kDT{rt>*V}esX`#2Vn&*8T+ z(*;-hTeogG``tJazcp*}#)8CN(%SN9T%jqp)a9EupUruoWNzEBW1!57eCO`nSFm}K zW?S3ZT5_y(UMDAyyr>E}CN0xuyZBeDAy({Vd)PSEWawU5=*zQKy?xE`QmvcZTwHzQ z<8N+i=4E~U?3IcH%-zeie%mMV&YYV=FsZ zj%R9WYQA3l`iO@4&S+x{-W9muXZrH}ARrQmE3KV*gZ= z_~y>=Fx%OsX?>z0G9;tZ5}!2syz@Ssv|Rh~efR!!`M9sI98gmqcYN=4!VD#T~+>ys)VzcPjnzs^ZqW53U?OcTXJ9hzP>w*H^R9!Xl0r*FKKAJRQmA= z&Hhis+8raKD)k!+kEUPu9Y1k`BK9YGTPfPlfgtJY*LVHQv3i6pXn*_ueNo)bkY~@f zl9pzMsa)OMMA~DYJfZk(-kH$ce2|Kcr-8o-McOLE#{O@&I>W}s`f5j!%URx%m6dr; ztmKJxac=Hi=N$hId7n~D*hVz6)KqrOe0z#NU7rRUZd|)o_P3|N?t1Xk@^`=sX#AD6 z_(%=d)(i{`h36GCG|n|mr@eh^_gCfAsZ%AxV!nG4JXtr&pY6FBaNtTr8t#k{?v|0Z zwl?wG4dvyYPb%*uCni?Bd-pE#(Iaw=Tx%Bp;NXaYBb}#rY}@ueQ>P@mWOL(X=g*&l zIPFFj7GHNz)0b)&yPneei(U54@n3I4U7eq)sc9CD`3HXes+vC>f})}?xU~x-on2f$ zbbT`WkZo8mfU1@J=1nb{MPz8`enwGy_5zzh-}2R~DmaXCva^xQdcYlBrW&(^SUrM$e{R_lt2WP2j( zXlj~wV+e4=XVbQmKi=y}8JU>WeEjq&3W)5bn9JvG%ci{Ez9a+>CL*#_2q>>8bY>1S8o-&IukF29jJ!F z-hulQ<|lniemP=tKC+G1#Zg^y;^1$2&XSr8*;EaQ_>75#$<7L@i1yWJZ}Z+$jG3%j zX$vT#9|EGH4!LZs<;jdIeg`BLHhrIy6T$cydn@`q%FbXz@Z-n2UORT(`Tkbb_=L=J z-QT}|H`LTTR!LQOc*dgZc_6cR?D}kz_#{5b_H6F#>Wjzs*=*44+BNJQ!JSj<8{tSj zI541wZZD!&AI@EW!DFMLFgwvA6>QNWc2^jZFZ&;i zb_RkQj2U%zcaP@l>gtk7Im_1h&vE%*9z$j*fKf4eK2my2C@~jv$G?X ztTWrNv#-pn6IFFv+Y667?#|t596!` zP86x1o9tAwbPs8BM_&}=oqW9a7`=V~w%61MddUsGBk?>cVPc;pe{w5lI?s*m<&8#7 zv){l%InSb{H9k3Z;)IW>nVDTV^^x;xZ{}p!mEsO`Nuzq0w%`_~Vq)reu!U@9nCWXngMnFlhGAo4qnwUT zTo}8e=l67ti1gc^{0t2Yrs}>(F1uRJt2k@LyD!i9tgkQrYA7pvI6gMUM@B~WZWe`vI?uk_ioXHFAQe0rXerU!sq(YY!Av8fzCd2*)8 zu6Dft35Q==nGJS%q1*kw(kp+m_x39 zDfZaRXchYLMoBtehD-l)TVJm8tdWs{!STg4bp4lo`K+w0k$|{;ws-H|HD8_?_UY=2 zWM20FQu>fgq<+YuTiNiGZBUxbLSz7w$g{k>yxJa}*x1-Y@(cTbri?{JMf3j7CSQM< zon!TjvMLKjrqPM7(*vh?D3ERIPVsK%#lJCq{#&z4>cTi>neTFNh%+)R{`pzm@bzn0 zey{gH=SS21|0EL3-KYPRY|hMFir4x2;!GztuG7TVrXAJ#)j_plswg=f=t6d4|EO#| z{&wdHR=ZkbpHWuMo;xcH3>`uD7sR5_dii@_q())hneZ>w)q=Zr?P?Uz_-w8hz#kqS zUWJWUalvinIj?E`HGBKhRvzAu9?g~mk^~;2r?)P*=q*$i^}sRqr0Mf<&#~;`cxn+M zwgV?`ct|39&h)DKK>B1YV4O1H5P7;0+Z>8{WSx;;L$_FjV?o9oqPq}pH?3CmqUpd02xH6gZ{ z6lP{-Wml>Pe)QGQ>l1bNi&pyWGiBBi5)#qYAcCoC4Jow_VM*l#P|Ao**ELm@lxdSY9UvTJIRAl)dVK;F)|&}FT5YF#ueUJ-_S4= zMmE|M*Q9;{d;urNuDs#p*&F|JtipPV4+$F8xM`C;n4f6B!5528C*-zbpPZ2qJK?-i z&Fpct$;uGx#$7Gyjh4Z9?jDy_K9>puQv|P0qf7)3w_}f>apO>W%27t9wzf9sQ59w7 z=Q68dduCQ&@u*O+i2DFlzru~BoIg}@@6YW^moHN^`uqD!E&lD9#>@9%V|nU98Q}BS za1Cp#qpx^&LiRnjZP;K9-sSy@?t_hsoCIq&U8EzZkTYL}d-==urDoRW;CCIt1QbTxeiTCnqAs_WX)UW=}eAufSN zhK72$cX>fM>a?`AItJBc!k{A(Y;SAcsby`&8b_Sl1Zelv63)~rIEdy@4*YDVRh?%y z+Fa;tCm&H$8q6P{pZ!aoY5zZu3B!H`Gth8;41ytee%wE zY@8Pm5D;!d`ESWKl%=PqFRiPigmS=l_N?^g#`^5(zbf(^su>MGZ)EojZ7Bto5phl} zF20cOFtMXg+PVGLuekQZ#)@q%EoCStNt?`I2K@Z|qzBT$NBWnh2V0%V$ai}O1qGGc zpnSx(tMor4BcJ;HowM=dE=Hl!@^ZLgb@#F6`0WvWuISar>?eL)?+QzLSJ+KuYll6-RPfJj{`Y~C^3iHX zCmOA?N7sv77Oh7c4{Hl|y1qseXwA5+nyDS5-T(0j>Y+;IBHzRu|6guk!ZRn`d*bivSwhhj?7vqhMPQ%@bJTAG_LLU$5wb05CX$8y|5 z!Rr{y>#L|p`lS!H5SSfe$({oTd?0NdJAQmy;+4PwhaH*}II>y7KJUAm1ruXo^6rVeIYgef;?G)5f!B&Q#(PVZ+b5s_Q6H?%H+P z$w}C_3sq8h^r83A79hh%(jffNn*cyV7doD*6}yVyg>;RLAN~IRRv%qF`q4dHjarGj zTaZnS9q5q!iCR!JXmcv3GCa%ar9hl*@G%G)0(B>P`P#K>!mdkv#$BGKoP!%f5tE8l)n@KlRmZM6(y;BdJ-NccnWuTce{z5N>K1j5q0nnJp3vQyk z0h3a1fq(STk)usfh}E+0Y-c@o%w}fj+G(>N0fIUL7w+EEuL|ftBea|0^iCAQDv-SX z(a{L~#1~%((he>=z6o#^dB*G@h~`5Orfu7|+Z@rwKHR;rv5{BB8REv(HK|{@wVPk} zL`n4PloU?4B8u9Ynu`D=dzhF4$Z3xBwI<0%f;W)*fQ;z*bVQ1ZkS;TcVMDpils2)HW7%7@-OV>E7S#WO`aB3w~P~}F>88Z)v7KzyK!Z_m@S_LO_ zd9MUAoXD}!RH5V5z9c6{E$aBUA&Td8P_+4SGP_p6z3caXeay+hxfMeNBp-Lx($d29 zOWe3|BYp`&ba$LUdzyd&bF=>?R{MXwhuLn#4h=0%PfvFM?%n%4Ni{HG1FKJNwv-u%F3pPi0(B$npcj~XHsi_5x>M!s&(1wJC`QY>D zV*3c&3~V!wO-?>Jnt#vU{tKbu4#vwz32X;a@kYrKb@+X}hz*=ne`ctb<4;b4fb)z& zuUL^%iu_IriZXOUh)u$c&e3u7`ps=^71)(Fn;UME-FcyBg(d;S62Gl1PIh`h%+an; z1sic*c8uWBnlU^_dIhd;ZD%(q7y6aCx5OhD&Cxm;Jw<25ae@EFmRDF*LO~$T6F7Ar z)|a@&kZWI|<64p?T3vNRrbZJ95&PI_{M)zw*iZUB`3|k` zfslmnP1195bBD<*X-|9z;D@-=kDdECAb_?Z$DpbI@#CGnvj@N(2;~NtOW1a3XJ+xe zg-Lbuu8-IH?YD?COd2*(fwV$!-Rt)b#zkily77dxb6NK9p4~1v;6Uc?nZ16f7lg$~RLX^L}6t;8C%|zXTPu^Kl3LHKQoU4f2`R$1pa`AHcy-~A{`mN~ zvv%;igt41vTvTt9B7J;(x-Df5tKXY{OAXN#wz9J7b+7LYOBCn{AnESdv11ZDkHbn| zSNAAeG35b{A^KwWeC2f2%p=B=0qDZBe-!*YnvXT)s^!_ty2+C0pzH2hR+7`u4eN`n zEPfTo(78_!R2Etj(A;fiW8GzZZ4ICG)Tt-!7KS)T&V%#R^n6F#XN!+6rYN}c*^en` z=S~jhjTmo)Ddo6H`sK@h?{>EvZ5)1Hkgb%=DA^~S)8wb{S@t`67iW^3GUlk;@Vy|b~Aku{8>rDbIj zZ!8$a-OpQzRR^dLhKM)*RC2>OXOLr%uZd?~f)eTj3V8 zZW@+oZ*zCO7XRkWsr3>`AId^L3JQusQ;?kURK>);fq|_!)-Z--+c@Tu?=H-pWMh*p zVBEiA_^}PyTH*(Th#<`ljFg+ zx`?=C|H%ciHGOqW0|SF0Co!YV-&e2V;43D{9BBNA8gI7|@s+uvyj*xwLrsk!5YeW% z6aTfrF0@?h)ddSuDN3HRX~>_G)P2PvyhBxijD>3-JvKKAa~j0XpBL7n*U!cN2(qEu zW|(=_8IAMU;sk%^;IVyAwq;i11ln!zK6cx;cW>G1eH&n?Xt$%!4_+Q_-)QLlqrYHZ zw@=i{UlJg&jDB#9;jE?Rx7ug8_xsMDKd;O&L3_|yC%x(ASsL^`r`eHR|78E7)DFL4 zba#e~H3e#8oou+-CACn`Nls4Q7-93ttbIh2objyXB@{@0z4C45y@iL+u$~{Gf>X-P zo$mU}H`-uRW*$Ay8Px%tW3?$qDgU2koeV#V_v-Q!zbL?Sa5&-Zu~++0yR{c z$=)JH;={{%Ob=|2j)~znsM*sy#o6iO?fptK&(^cYfs-|Jz7z0kw$7;T#~J*H^Jx0R zhg+=1e6a75VnYFF0%uE-l1{*M5N-qO{Q5Fn2bwP_OO0z?2L>l}?@{klldjhm)DE8i$2Cv+?b0^4chI8gsU1MX=Y)M824{w0^vDeM? z{MiQrN;X!gaH&-ghpmA+AuNn`CEZz=`}?az$MQr^W8-dsYq|jaslGDZ02Uld-15rt zwiM?c2{f60SO*RK+}pWREJFXJ-Ka^p>@OVX1q3M)Bz5eZc4A>xsdw_|kC^KhFFu%i z_y?_J@7}!+0aUHaxzw}0+b^!X-E+T>)g{oN3-L<#va74Xv)l{Dg*!TA@PmvWQ zqdqjbLY~&fIPu^pL1PZuIP*oXP7f#DHz%6Rs{NVTyEZ$A{+_|^3|GvIFyS-owxx0PnjI?e+K)YXGzOk@gI&ikNeb?=5>o8@)DwMB-8SIZ{&i z0K$XLUJr9APJ>eXtD6Uxo%+2L*0uP+}9pMxYGqz z9P9eTiTb-^|Ni}W>y*-}%&)ZoCrQI~&!5v0BDLTR6_53KI_zaO;eYTk!2rTrZR+kR zMXj@){+yif?fduszP<LJz|>&2L4@lF*CKky8tgnL7Z*CU z_C68Cq<1GZ!oyVMXvk*!CAlVwr2!-Y%D)=EA3mhh>@$Wx;sK6rkjg8errHr|?{rGX~Z|1A}JDwi)XYBKs zNA-3JT-`pO;0Xf;`tIvZwv#6Za2K35)qsE^zR=Wtk8gvJ4{%PZ*Qf))<0)N zn6>b`UO~53`)scM&|FG>$gMK1LDhp9ZXQh%5p#As_cyF_3Hf{3T#8(y8_qws(sPUi*Pi1`Iz zM2IlG_JV?f*aL5Aje#1&d9>=HPQEmTP09$~r^=y?@8&K0Bn&d$gLU^phfuh{8VDf* z8}JcSC(7PT`PlQdX6JiwN?>94Vedg((nC`#G~Ss#_~yULt24b)&s%BptLBH7$9GiE z)#4jL!l;@n1W67r>)_1C&ONo8f_p!!pF&{`Lt zDut1Y0_H;)7#viQxnXk?5tiPM(C&Ia5~5cZJC8;S`BLCTU}fCwyl!DiR*w)dc_41( z&blgCezP8Oasj$z845>En6mE|4fZ}_f}Jh*<|9e`Na@I&kKfQ)gPNyIgFKD4_vu** zZa@!YTe@r|UrpYp6DPUrqlm*QIeV-QZ6zITUb@a$cYU$ga)`(JT`-%PWpNgk_?m(T z5z=FCMdNqqck)tEX1hpq^3tIRK0$r4nQ4dNV&}WBCIiC1@clXVV%KF)njUUlz5FxnE-NaVW8CI7> zuduK%-DB4Gn8n?P!o&=M8IP_x6cqA)w%tchUo#ON#b1?2W4!rHeSz&C=TYaGiP#$( zcn4A7J>R)aZ|`kP@EzLw*fU>t?K>f(Tlf4Fh~5|MAse756k!rkvqq_=@RJZS6Lf`G z$t8J@Q8NiEOx1tuc3qy?UR+!pUHGk`ff^#`=+O9~Lx-lOr%%n%wa(4YCoTIRB?F1y zS?6N?v#*<^${A{v8UKXTrn+#$;jsNab=i^*WX?p<65yC}8NQxft-7$ZU-B^*5dfrp z0)cp{>cG_EA|W^^Km#9VoUC(?S5`16c>Y$qNE9aYc4K4X=bWm$Aqqibs^Q-)E-s!I zW?w~b%oLKEn#xq*Dgz&yXvA%Xzp{e@1BU}OZYk_a6tU1SF-h$n#=(8>zJ9;)BR1dg z%8pd?vjPG9#4#5917#A%xGeUdjLGCVm0qEDILV{3wK$P$b2M=l=H^aeAJ9i^iv>vT z=a=dOn^n`Z1FHyrm$lmFS)&7KNMy4;O-1$%uOU@vdb&kWyS~KG7DATV{8i-g$cvoD zHrdd~=nl|rw5hMJZ!7OBg3L@?c4qsW<<8-K{rlUSiV2>%UmlxZ=c&R#y|q=p`qf-$ zq&!;UGMaGQ#g$&a(3P%FCy_Pev7kGQ%;oVmwqDgCPN=c`-174Bg}x&3b~i)T5{^>* zVUD-sUCnjIw!OygqV41BJM(u@q{oFy>dKWJy#KcOZCrPACf=T=g-5#(n_5gP=4^53 zq-ykXZFTh{J02!+cP5yGBiFhiWQIN__?j$@H zYUT^kf)-0t{TwmO2&nZBi-C$umAJ1H5`H6<jU!A85#C z&)bpj!1>Scgv$N>r|w=>=Xv^-W@BZ7x3MFMNQbo~Y`4nG%@uAeZf|Lk0Lmpq%KyS4 zDwzs{wm!Iy?P2uRvG*`nk97Q&o8lY9Zkuf2b#ADPN*KwQ!V zPL1#)o53CbmSu0=+?UyY9FV*$bH{M4b_SXI+?}K60_xyHc+|zK7&Ps@JbSV;TXTV@ zg@WbM&doELk~?e52){4d3QSdaM-Z33zP^cxiFNP3UogCJA$~Z-1E9{4B6dc_!H1JL zIxVGul;0x==yd=70Ps8kQeafZSWOTfInm}}7ZDN1)Z3pp5VJ$b=>eebJME$uXLoo5 z=vryA7}gXQv90d#u(cIn3&6I~b>yk*dap;8sc5~nG@ZxOo~n2(H#gV%$QrEXBTyTy z9ULwzE7Sgy2#30sZ$JLbtQ$r~xTTql3?%?;Fp^)|xzSZ}q17E7^k_TLb-lp4y88Nw z)=V?KAD=@$QO)(wTpO5}@CRj+-6Zh(i?dci&;lyDx)=~?DGfLzXd9SvGfAfT$B*zJ zrR^kmcfRn2VR?B&>LSE9kt)SK%k^MpBAfO1q@TyZRj~bP_8~5@G??d?ckw}XF4#$x zFtmlMzcw||0Kq)%J8OxKt<8RJ7|t@xiu5<1|UK7D#q4j=sI4AtsuWb=zN5t zE&3xMJpACw;=f?@7ul3Q-Ff>7F9&X~0x|&Dvn)J3yF)@kh?pmoafM86F(TnIZo!N6 z(x@3MSX~5mc3{Kpd2$KbWW?&cqTV>{#;RcECR3(pL@Lq7eQ%}TW)TrNgiOMq+dQGJ zHm=`D)k}hl5GiQ98`+@uh)VVPGiRd`w&d9v%je=)zQVVOv!ZwRt{)+RGF+nkiB;_X z^Cu7=Pzw1~e4YyAmW1^{bsr%~Dzn)Xye2F8xqI)SLr-9ol(w|&1;t;!N7Kj)c+z^F zAbq^rg{8oT&N>XRo)2`W6@%~%46Swsdzmww$o2?JijRM|_t+KP*4NisO~Ebbk!S%3 zeKa>`PY{g7r?>BMeOxmzGV;atyH>p{FDpCf`dtxV2>Wy<++|_%5W$uZH%h?Xw6?as z2+iviUKF%b4;E~pp?8%7w+Sw9=(PQAdcKR}!u$5^YfibrEU$kJd_QTI3UvJrZ1Juf zE8dvRm3cGj!=khWj{h2^epB6gvhRe#o0})(^;@9bKpc9A7A2vf!Kj*{*&m%*8ccWD zyZA1LC9#lM`qFDyzgAPAcqvG$e0RR=92*}eNE~7v7JHL#p2=P&LcCch>u?1xb zq>J~PSy-q-MGtxLg7U#HzOm6!wvJLLBdn9ZN|+%GT?T_v$hvcx{?H*7q?!P{hqO5J z+~PG9@R@yqlD97L-%1zI$+$-R!nu2seR8w&^OqsqY+37&hR!oG7oH6l^P^e2{Ckb| z0m(5=H^_<5>&RglOQJq(B_p#zqC_*I{<`}&{+}+U4!!Z+*Bj|!B=hvnOrOV8YuHLwQJ@8580Mj50 zd}ii2%^bz!`-bUnxDWFZ!gut0Y!};#k+5cmmp_|&dbEb;hwlaPRjDEBY!#D@(weUP z;;qWtDj=@%C&#$B0ss+*bE6@qTfHzZ4uATz6m|MG`>Hj}5|OHx>)j2R6W3duvd_V6 z$XFFxZ31{C#LNBCMn)$;eEOvFyNg!z0>Y6-5JN?(5|KuGc}rnALg9Q~5>4O8r+w(J z{8b%@F4_i236<>Y+M!Z^)K6igt8POBNRO8%el?oIAov5TxH{D6cUTTe-d|u^;Zad< zZr(rOu(o8AH$hjIc}RZ}P+jKFBNREht6he{>DkNgSnJ{5#N>!Q+I@JyXX00hIBz#3 zc;|kDid0*xnoPU<$P3BDwOzPyfg@cucGB=!;qx=zuuT;BGMhH`*ZHgD(bfDg^Z9%^ z#8o-Oi8gP8vK7K1ORhX3A`s$F=E_+b9_rN5)zvLd(TKYLHm$zIdDb^7GSa#nBLP!& ziIBA3At4%V4KO~5czv`r0e6&M3@2%%YQEa3@;eZ{UeQ#m?wj#}PU-CSaD^fwcP(R~ zf%qxA`S`#tXn#W2AE zhc_}iPYG}WXlHFMzEMg^aGo9UiDX@dSL2N(s*m|MHWrzwBg|2O?uBS0fj^(kk5=LX zIqBGQs%X0&YNJR{8_YXR`?2PSXT`*4cu6Lm*-0`wEp)0NMf7Jaqg!MS|8=pvg@X!8T7 zzkd0J;vM)~>^jBgjR8I?Z2u^}VxJ#xBgZn+GBMw}3YxL4z>6n?008LgwO@Z?Ho&H@ z^r57@N-r$)UBH8!+4EEmAE?)sIBFXkd77qImzHS9B6_yJfqE9r+xP9;xAzu5g|tN) z8yW)I(>LLE@;goSNnT4y!)I<9rxa(yA?(ODq(Z|?Zq4^gxPY{RtK3VW>+zPl1r8Iw zxY0hdO_R4WG^6wnm)F-*LAX&C?M)DJUt7|w-Wqmd)nT^je6JYG>#6^kZ4KIIKY9@T z1y^AN+g@R=162C7RWB1Lh)U{eWjX76XzhFU?Cx%$>roCfm9qxX^!|8j8{3OwNk95M)@1APt;S9OH{z2ce#C1uv}x7K1xo z&1kPZ_OKY;fUuwFbmmwkkRlM*bpV9MhAXf+(3} z0O68o@Wh8hXjVB@VGHbjVfQs6kHwZl)dAU~AImcK?ah9fa4rf6+Gsc+kb4vr6$|+_ zhfX0@j&ARR^urtiHVDP>oj&cIpD&7LuLw+l4ZOwwAh#s4TArSs%3Vep8XB5K_Ya^g zUc&wh%c!AC2Dv_V^5hO86%FZ%i;L^%L4VkqRCp(NOxv*xmyuKeRcEJvgdczu^xw!0 zNGn3F?2y;s@OS}>aIDSvBdrQO47WuG^GT9eD&Sj0DnW}cp;tTA4}Bv={$e_O7y&%7 z00itIWIh1el)HYt1&TP%Y26=&GdOR3(iQDcO8Jc%=sSM?q`Rwc=Dx?jlIYr4jYo(m z^y8{A2(aZ6tD+(sw<>VFu-nS%jrB#H@yJ$y_GT~~hgvQ#HxSp z=4NL8h)huUwzs!SD=6$4m6%6KOVWQwr;4Ey39NOpT$CrW0A+Z++qZ8wH#a}X%Hqe6 zi4X243l|qPqDl&5NpSY@wuP{GDr9smxoFGnNXRnH`bij7JqrbwvYnDOR81`a724D>H2ue zs3|rOlytUgMEQ#1H{;8!d$(@gI)z}6TqNHXY;PYB2Mn+&Ag!#Ne)|B~4(d|GVBS0a z(?-fE_2UstsS&K#?8m#1pj}qFx({%=?!YJ~Bpj=EKjI|tcN@Gvu}Kkr%@gW|(Pry= z5$oR=K^L==rR2h1&c-A&j6Dv~9oV~f3)G5nsSJc*1`xTu9G&KQ0ErH4f170J$kYhB zUw)p@?Y6mb_};yHPB^_JB5DQkgX39N#?Iv!F0Fi8@Tw8H0*}oLY07V&{hL6FX4q{D z^Zxxxy87mOC3h-Dn-AUs_(2%+*6AlK-dERy=GC77v=jFu+O(u=(4jNG0HJM&hPS;> zv3jp@jZqPEnXh-k>bJ>1hS8cmXsfE;r|JqI6yCuh&vjiExT&O+)D{Z01a+;KF$8)G zhvT4nB1%=66 zwV6Ytj3`+s$wb}u@!5`3s|U`5QfsGu8+GRk?yn!;r~gJ&Fa5~86?Z}$0vRV6wl2ldR|MgZlt1v_IBacgre@rF42Z(-hzZv4JvW{H@2S}3JLs2S z-_;<%(-=E}?R1Qve=m$9c9c!xI1(c}Pzg~iPW>L*C+XJ=*;O&|$}Y`Z>n*@9&|k=k z-Pb9xkO&D>)Yej9pIriFjgOa5Ir!h5AZ~r`e*XK05C6ZJ)BpKf0xLts{~6Bxe+HKS zn@aYn!ki{+Nb1!#%!$boyc3q&J6p5 zX;UnJd@&M5_lGK5J*;Y2Up>@*c~?+n2O^gQmCnt} zE5ob;hug9)a68_TDA-Q3h+s@(AT=H>floldGbCg$VQE0P7=-k{J^bV~AMJ&d^z=ZC zwi0HV6XvL_@?cX191&wcnP=C^KaFVr_Q>Fl&ZEkJ#M2ML9NldpkPP40KFcBgOJkEX`P>!S! zS{lrO5wR^4QPR0{=f<6-d>n5wxo)f-0*3L-$vKN<1D5;B>R~c}&ShqZ0>ggWke=)4 z>XL-^gHSEew(d=KY^hRHWD8j;>yneWcboKi6N_-?mNMgv(T~t-23*%7ngk%6rR6&(_w`d$6l>W5I z(~ss>D4eK11IUgQz7OChXauT1riuT9gGou-0o0aOR)}Qv@~7p?9INY+A#tx+nrmunL6UovIdayB&0YZTniC55vX>1 z_ys_4d7mY8i(Gi-5M{DFHr+J~o%RyC=Q7kepAR1dktR^UfH_J#;kXhV4f7N>2;#tM z2>K9<7MuVhr}^8rONhkiLa!ohREYJ$*nsHf7;ujsH5^j^XcF=q|M5{9L?v^V+!Iea!;CmDi(-k=v&F~D?$cy=X# zo+UCFDMaF%(0V*)iJ^EP?ZGck_5(|?<~#n|jqiP_tn3FkBvQA@s)RyGHZ{oo3IQBl zV!)t+m{0`Ctp-CT7>=W(BkT_b90Ve+0!7q%X-bd4jewc0!$=``6Tb=g0H)>NYI?k~ z0tcTTvjC7%%bK8r+L2kWN#t(K)E9)o*ePy>L{#F5Uw_lFXuKQU*CiHnX;W^@=L@WwfkjZjSL zrG*MhB#5zrOz`vtsOg?CBoJ)agL5PWEPa1r^3<+4#9eT-FXA@sVmM92Z?Mo$kdrq5 zReTssqY9#03v>t1PJW*+O-;dzit_MrF|dc2MdM41UOms8Se%~^6S{p8dk{2@h=%TF zz7T+w*T-43jg#0&LjWFGE&^UuZaG%@17NKKl8iclNC9d-St+Qazo$9p|$(tiUR+Q^C0lY}rEu=N^7okhJ=*inGxF} zq9A#bl@b^~iU_X(bwwdVTVOj($(Dh0D@Y>7y_~hL#KLRg1I;(m)h&hlwY@?UlNUv9 ztKnyhz!RuRu#b)y--mjI_e2P_R{&lZ^@Gz5sJ6~}a3Y#W?}3d4NDVD5A#_zx-nZ!D zWEo*>t8jTHNJ<{Zq8gy;t@4|F6Rk;9WS9*o-&TR)i!4+k0n$4FL#~r2ukbZMOOJwg zV5JR!8Tt9M(CSpbSL?X)kEX~Q2p8?t)~~FE`$o*LcfjmCq#qz6BEnDd9F>#^(IofD zYaHM>di22-U8p)lpQeJlYJmAOBgkI?{8eaQN@J_WI9B#@8^53rci==ix4~nmB@#%a z_ju`pCJe`y?_Ruk@nHrV+p`%Jn?g21MLfjF_yLKntMwejXwlUC!O@9Jw!w(u)heC$ z2@@YxzDIYHe*S3f^w;2t3}GGdym53u!82zJF@&c_>OhvkU-Vj~Byk3wK6_T4ot^E} zrUYcZ+>Of^;_*Oq05g4^=-IE^WhEuI@g zd|h&8W)nslj>W`cq{^rvl0{4U9x-D@ibN~_Z_+a~>?Um!&2eXSd1ocdCfyUq7a2S4 zZui<(xu0i43gR3mMuy2sK|<2zglO3=uQSE16{gC5zw_=|=*d*v-%&Fj6+`9;T}0wi z)jyoHx3||dGTH%2z0hs-mUSM)awJ(vAhAAbTvs4Kw15#JLaxr|ios^$5fZ2|5@@MV zLdvjgcI*2_(GJEsvY-JI`ha4-y(!#xa)4y#-g|x0i-|BfB#@CM%%$En-0iIo`WP=9 z7PTjl8ri0a5gR}Y1wiZGpAnc~xd?uZ*#RFkmQj&E$k>sl66@P15#t#y7HKVVaU_*O zsuyXaVG@3T;xia5zR~9z0tGw>aA6Axc7G^ZL>_t&1|wTuU)$f<{-RUl+_^;PLxtm5 z%ZK16$?`FNI5S&GLgpP0&~eJXe%*xx%F}{glyC{8br^-r`@U+FxQmf=!4y2}dyCx| ze}1;`LI8%C&~_3_o7oZk+AJO=^cA70eXId`VGd#v*$Q+5YODEvf;{9Zp@s~=MksU^ zyFYgkrXT=FyuGM@;-o&$1$Yi%xdcRmBGW%I;tLmxv@~t_8zb=;%pm9H<_12xHpNu) z7j>M_dBCa&^@Wa(4rTEvT@b%7NXIJ#bXzSGaX@yV-@Mbz6Znte2d zB67!NJw1mJK_n_V5f=Z`btZ4H{%Y01tFp2ckPLwAXz+gHxzR{@tL59DLNVj{8i(h0 z5K~kz<|HjeQ4*T5oTj3X)E9I@w&Ph?Yi9^45t=XYN+8914j#O9bbO*M zg@%YWwgI9O5lGTfQt`IO!NJvNk9wx26u2%zogi@RPY3e_*CF^=cn(DT90MM@kgmI^ zy8MKH2L!9J-@orb0ok6PkuWzmu7m*$t^4BXS_MLcaK>SFk zfy=>}Y0lXkDDqPvE{Oox1jt2b7jiW33Nwzg+#Osk$!}O6JhbvW7ZG$p0FO4 zHvn6UL`ueR?)Y!n&z!_$>ophY5c{u0K$Jdsv$hH8~RP32Mkzu-xbDid!McmBQADw~jiVlmDK>py2BF*4BMYqu00&CIan} zmK>CxAkJLz^Cu&w+ZFTt8xh5c2X(>DAcUoiew94wx3^;h{QWDSZsI> z8#JIR#H9mO4li&WKhDa!72up5wJsT6ikSeNerF?nFpTbOsci5pBG8*1;K zJ*8AH!^8i31QJ{$;&&h__m7P|2F#iHy<;CFxd;CKR79l&&muq}(EWDN9s20&J9aS~ zPmg#R8TkszyJx%lg4ry3sY!dr;ef!vc!(|Ev7|-rTU>+M+k{H;+K0e0H9a!bJMxo#} zXk~=p0lnD$(}YC-e@ojP7*E5Ybqe#XskurFsJ#@jkOa@+gI49n&(px9GAV*pgw#v;>5#1u8nldU zXvGcd5Sv$I$GIunk8vom3tUzL-4E>AqE11PQqv5~y6e;~?R zms%dL;*my#)r~Uz-;@*bU`d}PJo#Z;2wgxZXLp^Z!qYC^R4Th$>#Df{vScMxZtOt{cl9`#g(0!c&(;Cl4g7KLW zGt6d9pb<~a&3(bb6Tkqvv2s_SwI$6Zcwzx~Le?wvPvpuXM*cj@FhdTdZ+O@TvRN4% zB8(9D!UXtl-ev-2i+Cgz;TWKL>tVDM3cM6@$}?*MwZt?S#xcN4u(5X2pSerRaOSQz zZfZPU%9AD?PwWVUF?9!w6v`nx`y8 z?g#Cc#m&@ zF+XIW)0W-zaJ(+#(y8pvr$Th?>&~+yoHo=8Ds!>RI2;5QhHOuW{eaq4sBOd#=h4(Z7b#8~s!ceDg`cuj;TQQs-=S-$q7^Rm`Q#N1BZZr#0@5% zQ}^>}baVjVFhwk{Bu;jcgSj(rG+6Et+{?=f3Qx-`>1~i11Nx@Mjnf&%QN(d(}NOB;+9uh%N?Um)4dA@#L1r&<|Nz zShhd{LRX2AY$78-CTpiXjZ0Ui)RM68gH4H;=6o(zbm$$gg& z4GnF{wb6&ICevjiu3JjttKx&g^U>mGcup{~K!lFchjTL4`GJjiPZ00i+}xp`^&8oe zQc}d?!+1!9!bs15$9qoH3`zrd;mKT0xPM4CPR&oK5|%6x_rc&Xp5>>c%lqe)?9KM` zNP=bV)fy8?;J^(hkUYY#vxK~N(UZTri(}^1IySQi!Z$(a8`q6a+nR4WaYrzu&$x4aAe0qIzfCp+4 z{Fu;k(@@t&D?Eh-wePh!YSAs6%9Hz^Ai?!VcvU&vAPT1_8`YAZ^bUV=;fYB58#_Cr zw&>znR(k?c@L{M)2w|zkhKKxL&wU^N|27KQl#$4YNMtJ{va?f^QIvFAky%Po z;w&RYG73dT(z2DNwwak#D5;EwmBQ!ta9!8&{T$!p`yYI-AFks(&XafVUa#l#G4A*K z6Af>dwm&_JJIj0)FF;*fL;9>Nf?vU{KTSNFCaf>Q=2 zRLbs{wpg6~x)}L=CfH>CxMA7TowH7>b}tEUt~kPUxBV;36xu76tXOf2&qyT@^bM9# zK4o`A@RJ=}xhO%-pv6{dlS+|ML z6cB+|`s7gsC4JGUF*E657oq}vH`RC7uEiXi2nyh!5#u890G?i~`|>=G*jV3v}qfmb5^-HQ5}1Kt69P@S*&W?*Y{T`&7p_p4e{737N? z{Y7xn2PbOT;L920s@0{-=#sZBF$A3PjPXVbs;+j54b0C!+UQYN<=o^)L`UBFmOsCP zN+@k?nI-f1t~F`SWW`f*tbP?WSv$pS3i3Ob_E|s4B61vpeD5$|@xSYyMLi7*hu&Yf z@jJDu|AD28(~FH193&WLQ)eyGU>Re*U^`+1aI09pB8;%wBi)I_lqR3yz1KIu#3?J@L<+ zuujiI$2C@Vza3<-Zy=TI0LtJWX7Ge(|1_5qP_S z-u$acwD@wmqDBl`{#aMp{p-i)fg3xTR3_6bPdr_;F>lC~+9g1Ox%R#D%5251lB#jB z;wg#1ASZLf%PVV&OfZtUqtUUW!QJ4+2E&F8P*Is`r*wY-RUVRE;QdE8jF>%LnsT;+c;P)52q=o4)!$fApszzY~+=wle#uF9nLN!h(>hI&BWoJome1Nc@DO^i`ZbK9A4hM(+GLF zQ|8Q>Q{H~nx7g0M*~K}=MYiUVOS23k`7S|KgSvFtO4mB6WQ~3=*M|i^_GLyXXCQK` zd^btwSYG3k{tc?6OO0qZQGGKMdhcNIj#ESf3|FPtZ#jN!0q21$_qEhKdEMmnR|p9b zQl9IaqyxS-a$yZdfMqxjMvXqpEYzBe`cBM!B7m|bKCA#6V?ZpJ`_uP{@wg3ny}uYI z68~dJWX0`vd>BA$;^AdBDP1?}wMp!;*r9LE#!i_=BiMVTu3s%i4VF~sVlcP1qkTEo zO_ETZg|L4){a1Y3;RKt(Sw1Wb-LOQWz^$7%w;_aB;NWQs0I;fyg- z9Z5*0N(~!jIr-&f+}lKzhs~?Jy`beB^NqQ`_TPv-r3=BLJSw_>52K>)$4vKY8O7j| zsOE1S@6_d`HHB1BpXCfmA_CFRne=Ahph1NrCi?F+s~c;iKe&JYMaUd<8D@U+WyzKw zXW#2xNhynZR&doequJ7g-n|D7oOM_;Yx(N^o3BRvD<)iZUi3GUxmU&OzzH-JC!71K zya~CnIR5d6rZ0#YOB4r>9kY(i!8WCzA_5#m#jxQWw=J{31kUT0%jdS!)Wlip;Xy10 zzihqSb*YYZLi|Fn1YoiF8hZ!3x*^9xzjw~6=l$YSaCHclVg%5GOsO&k*TyxoXO>IR zV10^di+)yFHw>t z=PM>W^^!?I+MQI8!>>^^&yGNRkM|X(RY0SJyl)n%iz3|nm zR|Z2qWW5XIx`?pU!_J=}A@QjbBepx9>&A7^2-LNN*L`Tg(K2`&*vmM;W!~=eKCCNN9BdTVc*~b`F77=R37nB>pfu9 zkFX-h8)f&r=g-%kn!Lvaj->;ITfHPe#$=3|)=t@ZvWxYr#<|sD(drs7AZs@tDnPa+ zvBUhV>9}&|1)%YAJb47UOZ{C+!FQZri)Wj-354&+?zT=du<&fGUgkvId6#$9Sty#L zY5Riafq4{o;cIOR1$gV>L$i`~y{cxNN;pDc2lkZ8y)tkVN948B{U`Ws+E-|~XGOPP z5D*7p(CQ`Kv~i3qgSY8;;lc&mQOUl=cPuJatv35|dYR{25UKX%ZxI!1)l5Ef2gWK} zY&FT$Joni>6DX7hg0o=&t96wDN8l1o)ce*5+Y?_-CJ=dF)Z`t67}O81 zc@iGKhD`oBJn_Gxrm}BY-AYcGT0Jp`Fdan&<9G1jzf>avAy1q>J%fJh_2Bk-k`=svzgUw_3DR#V0Jq57rXTBjRkx<1 zQb7YNUqt+TJx-kHCi**iik`s-JiuHBV-DACz<^dXNT~!KF&2`4Gh&1WAoxwVxS8mN zlZ&wB@MO$aA>?iaL$Z*cMc1u31z$csZ!3CD;`4EAR`@yM5+)Near03jxNC0yh@iec zCXosSKUGX$g!`b%=?8u^g_p0Gz_E8yGTViZM@Gb#-c;qkS`(CV+*J-L04 z9z7Ip^BX5->n{i1kOypkcyW9uVAQ4GyXKMnx`44cIy%O`C5!Yaskj}aV}>d#Bg(ezCRu&Jh3$xTD{ql8f`T+)ekRB< zy-v{{#`TT{wKXnbf$r~+Z3mbUg7`Zb00n(|PARd`x0Db4kZ7k#Ul?SLA~ha%k= zyuSQZkLkeON;%_=zh(phsP-D{aC!hvjQ!x&F}Q4pgZ8ai^C@yOK+(kK2#8ig>>OlR zg-~rZXG8X5+wOK6)CfM(nm-Ltr7M${R5QewjV#7 zpc)rKUVAu#=f37XAp+{gb88&7qPHJc6s2bBy?ZmttHA+UgELiNVKMEf_I)hugkKlM zg0PVvzx8)9{BBm( z$Gvlz6Xx%5jKK}1vQ&S^1s6`wGJgN!{=n=5n$k{3RW&IMkMtif z;LX&1i&gUmhLQIb3MkR;?R>S?)_~dAw}-4qHaSa8iip<9vj|0e%U%5$JRg{`E~mqn z`7<%~RCHc>#=oezcn5gWi|gxmKDhVjg8g)``%qBMc1CXT1&U`RDEEK>HLFCorWiM% zAuZ~4z7=XI89K!Ji3ZA3WTd#LL|?w#riorFZ`nT}?s}w*(bx^#<0AoJ#%}W5H)z`u zBZs_6r|gZUn^}%;NhB-$P{s`v)CT-#H0_rD?z~o_YC}bo0+I$yg_- zUiap0YZ4P>2gAu?u`gXae|+o$YS*{H&!Jj!)8CWGz0@^^VplD?4;PoJPO9?x%n{Mf z-ME;{QIH%$!{|hr0@l8_Y}hteG`R|5__^H5G`|_C30g+`p1jvPBu2PLwDl7AO}XpV zRCGMTJs}l|VX{nj{?!(gDXTYb%-}8p`9wUtyjN4}dnW{1$<>voCT9-cPI!2akb&|O zA9HV^tO*|crG|LpSy;3~bgMRPNBq~G%V0LZt9Ovpp*)*G-hw(y@&uq3d6c%He}n5j zh5fU)bOsngzF?0TFi>JO!wwCj{5O5} zR37$ioEA1^^H}3UZ_lkDsu_fp!C#9rIJB=H_%w34H}s}|RYvPd>heo0?e+-#?`y_#EYET5*`p|($4z& z?ND3@OHP&&y*_yMjJZ!tCQeLZs1}_0qlRYa@_+tG0`&xJ5y31nm#vLj_n4S5thtJc zF^7{kwSxyye7$L`*4IixWl>4@W_Fb-b~*dP4|IS@PHH!X1fpbprkYTtC_wyXkYG`- zR{xKLt@ss37C1N$E$^ns58eCDIggoYV=UGQ*mK^8e2X%>`;}Kl7J}^emqYJAQq`Q6RASNs_obWQ5vWMYp&bqF~B zo=%27;53UUmznG<3NFsZQZ7}921vc^3mJs`VZMNY_d27Z>oidaAP$9*3KhuUSTf8Q+Py^A$#H-Sl>M^g01wK{ zm8);v;WbjajDr;9YxV6I{}Si z4V>+taK2&f)w5V?c0dpca41A6uXlCn!(JkE+w!~GKC(L+{l$?Jg}1FHjnmc!UC6!ilSZTKkIVfO8pGB#1bp}8D8j=C!N3^E8?J`VP<@j- zo9!A8Pi*a69lUtcx3$8Y(TZnK=vAPNN3hbK5b*T!$_~7smTGF<``qLZsU$?~ZH?ls zWPQzWzZbI>TTk=bJ9L5ZtIdni7-nW?r+~;hSx?7f^55_8Y#d7VZj9`S+^Qar)svBH zDhQdTy;Dn(T7JN0@uU%=kRz97q%_|GzNrj%fckh0&Zqnrx+nx)ZP!Nh7^Bzxzs=IYB>ARh#XX~D0DK=^~>0YtQR6#<6GaIOuv)YuKi?vu~}PX zH`int!X({?!dZzr13*u7qf{xWoIGYl#-WRW?A<{RmRz%E1ZaoCo$xy8^rX_q^jqnO zHIxGPz1TR$*x0FP55!tqDOi|{L#X8eU)MY6=!gB++V-4x)a=p=84VZLqsDD?WeJcN zxEN$;Evw^6!7LdE(%HFym*ecXI8kp>!X~qAvDIs=3hftosCLBP!$WJR zRD1o>odSAJNwk0}MgcfuVOZ`;Pz10w1VFMT0L9;+l|2m+X z6*vu`1i^CAddF@h%_f|gcZSw-Asvhmzno~18OXvB7#Z13RN3bjk)W=^E5FcY*QDM< z@)ARZAfj>HN}@Z#lMJ#uI>OmQN+}w&2WnwDIcIcgJnQB~=768>f*0 z9lY#e68fjVZg^MX7<<1NW1DeuOL6*x2dz2M85G3zizkd5S4Y7d>zdtV#AAWdu$3y> zaA*^T!(bQn5P2n|)Pr{W#)BYNHe_(uAfco7x@Ov_KDj7oa6#MiGUN9r-bnW!DccHP zn}KR51p?-eqYMwLe#Nz8s@zf*+gy*2ZLW+Wr*_^DeUnaCtsyy~jS6mY2T}mR&3(Sp zu%eZJZK#o*o$bHj>%QQ^SHPF)lV!9S&yO>v!`w9&92T`YAnZHCr@vYBak|XPnPq$z@6 zm&H!)DS_rfZU&N`!0!b^@hsr5rb5q2SEWH;80b`G6bDE>2GKdZ2s&`XDF5Jjsm1^^ z&_qrTx;Ni-K{GJkN}Hg(uV%nAzKuu)(5|I`6$a|nuFz^aR$Kv?MUSeDnDSc=6gRW0 zme#h7dj_VT+}+l=d4;N>Fwf3dD5y{~{AZUB?>{#x%w6+)AH_cpPBqp0j5(mk!WaIl zjo0@@?D+7^%zD6jqb)2Bn!XHwC&e)6{=9k5=XcP!bn#-L)jV*qbZP?(ycD!zVb*(C zn&kpxbT#eYJg`zuTWrPbBWPPPIS!%R%jv zKbdBsF{;*}5qfd-XRoP$MMi-%+MDzq3hcV(K_#8KOxe}iI5^rd00>S*WRZ!CI(6!Z zt1N%9d!Kz)S67cNcMfimWjL6?!0CT`3zT|i?(Q`{XE0)A4%inyyX{g>TL){4?8>UD zS!a@G=7z0tb&W4@%~}^??_E7Eb|tuByfZ)6i}6p&hdrF{?9JCa3ln+uZ+RVzlg+wb zGdCpGX71d>7l)foI{K6O){@nlv0slV(Dfw5J}vId2QRE*yh+lEiabPn|!18?ZGG+<|o)!6Tggy~lQ) z&D%T>6a*2cRj*A>hw^XSaKC1CJGO1B9qdMgqA(q5^@A2^h~t~Fe_6|&0}-V7p{pLq zH1dX`Y_&TBD0S@>PkFsn@w-NFJZ(h;%$#4*ZMF+6Cau^j4nmvR%WA21ig%@@*&?f~ zLQB)V()M zaxqN7HHR%wH}#{=esFjdpQY$WKZJimR=XNULw?YKItx+F*sw;XxKAKBDY|*xgA;;A zZa!RRAwr$ogLgXz1iG5cGc5>bZQ2reP+?c2qx;<%03aA}q}e!TRR%C|AgcNQh^l$HXkKV> zeu0S`Vn7Du_yb~?czwP^PF4qua&Ic4k>Feaw4e5rGZI4x(g}93H|jv#=ev8W>|(DSAD!bm#7Y-$xe=e zh64up!*D0NTrKRh&+xMg-a)J`68Z-n_>tCDYo?3Fr`V#Gi>rCozVlJ<5y#r~VEm^f|blQf+ch9>UJz+gqwD z!@jJVbj`x*(6skXYD4|2HiNsW=(T1+9>P_p2W(+L004qxa{8>P?@GX%dPz%Iqj(6Y zZ2hT_kmR@_R4y45i8(uKh^$laF=0W*YDJ15R0 z&oqYm_1GfX`+*#W?uS}ZS6NFOv}8w}K=|Ze)%kY`ZbSiuP+qygZPB!t6xD?Pb(>f}h~67TgGKbH zVdn{pxm6?d&yBqS#P%+zhCFra)hnyNgz6VB+x{Ea>9X|6K>v}4HAgrll`J`X zT-it#qyiZLEFqj81vo0w9f5~_R~Kti2HDaH1-Ts~!0w^ye@n6#k0d67q!^N3fVuE! zN`Y zT6CUj)7}CA8Pr>c+jL)@9 zdP(@;O*equQI)v>nqE34FS3k`j^^hZ-Ob75b2| zwG*VQi;j-$pkSF@9Z~A8O3U7a%_oKd;zNNjZH&WJo|j1N1SO!W5;p`KV8wJ;Iy2fQ zK?v}HsZT-a3Bky1?``B4I?h0$ z!foBL=Ll!TgHZ5H6cm2AiEB**=E(ERvH z90lksL+CwZDozH^c50T%RC!v zKDJ%h#WaJg9pQzB(1Mkhm%D?+i-wVI@Qu7bmhB~Y$H~4O1fn#rNwhsQL3Oxd;xWn$ zNy8>h5MT`&hco`{XX`h?`M7J^J~l~cA#4zyY{#O?bX5pXV`1Sa1HAZ0(WKLIXHfo# z8}O9FgSp2mCsEDC zESudtXhWi7Vzx2>PGQd!$30eM3S>%EIu^#|KED*Cp}MH0Ve`b#EG*9O*CSSREaI+8 z%;0A!fcYfYKvn3mv*Y6TmZL6hk}Mv&tx7b?H)Gi|CIrp#B$7AQZwUIyP3uRi$M>6k@|LW^Uw9 z+(e^!5Tc`oj}tJy{^D=ScRmJ z)ffOd997o02~ck`*=GmlIx?d*S5LkhyaJ&$X31^!R~Ne4l}+V!zh@wD1Y=AqSMG)v zq*7KV^UItH>A}F+3h4s$ngEs$&(CWMhviZI`SY+{lieCMX|li51zNf{01P7N7Z!+0 z772sy6W3`IVSQz)K$OI|DKcfQE(jVU95#g9+|`gj#&jw%_IEo;kImIgWdW4L05(F0 zl40ZZ0*Mc@rYIC4m@rCwqS$0}q6+wKERTu2xWllEPhn2AG_OKA$_+z@f1JQNk--Kc z#9!+?OwdZ~jZf}XTfbA=Pp_uK^-4D%({GI@$g12}n(6^`$bVOUPMvDWJjhepd8H@& zkiUt_f?I=>qJM8B+#PVyo;+pBt%|89xtJHY4>6qbelZreaO)I3qXi{;9uNI=`G$9k z`IVX|p5Jh}Qcz265X)B{NQ4*J8u>=XSUUi;7VoxCKltTb25t})Z;)>9O}KQ0f1dxK zyK~(5Z>On7Y2GX;Ch}-QopTn%JNdCU=9q!Lem{l@L>PsJ*FcqgVZgB%Ur z!)Pzre9hab6DKZ&RmQ<#S&-h{{fxH0YP+WuY2A*%vhP$LE_u~If5tL1(Cgwynb}py zcEvSKMd^KQRhJ!JUi|^IXA_ZFIh28Fji#P!QtrhBMd-NE#Hpq37>4{{*SVKy){i0z1P97We> zxqtiur!W4a_N8rauj=h@NdFeZTaimiC0X_5%VH*N?0`mF*Df39H*^7#(QF#)L;L|z z$};x3P{ir4mV14(Mv1$2KCrF;`^(Zswd*WY=GcP+;*vw`U>Th?gt9uZ4hn>53 z->x|4$5X(4@@EY4*8yb4B`ABv$(Ll5UXwB$0iVdD<@NDv$t0~TPs6G9PL_Su&cb$? zHCb_1Rg)KOzLwVRuXYkRd$Fh@N%5ks4nXXQ(b^6qZoYx1=)*+JCn`T#SjSnKc``tQ98=NeKE<=7{^v<~QTc#BF>o>c2xPLl zRm$f!?b}-xEWPt-PHl=`jHY!km)M+fLi2NNp5yX#vA9`YC#5ynH3;u5xs-<}9b^S4 z*BsY7H;?b#%=bj;IIo#z*R+fOhyipL@;m3{<UKSMM*+sOMYh<}(_4YGeH9^Vrl!S!qG30PIyHYAV*)ETqo9m73~#-a7H>(R_R+V_Lh!gC zJyX6@jtI#h;!oDRV2ddMZ_mqk=Y4?(YIUE#JCF}A5z8b%3^#NDBGywUPNcF;E9(FR zl+(z}-gmb!e)Z_=tQ~b0ieIOMiiGB^!bDL zjhJ$kY@R)5PN1Ae(5IuU&OG#B?UTD-yt>&M-MK611aI^2_hpr`q(j^(`uF9izmWAvOt=!FWNa^SUX5?YLhC= zUDqr%RuSufzhR8zHB;MbuW>i$XmMpUwzOU}NZ&tnrmC8NY*;q;qS# z>bRoW*M0`5E%7`NTlq63>28Gnd~NDY8^u#%3x8^x>V#`IW(3YLYy0kJ@aOx9vNr`F z0ilyuY(mOEt94C`Qg><<7Mp4u)vTTJUft{#xPJTc6XzUFUGINCU3i80*`0(c#t0kd zMh4T|8Pc25v1gDz{?bTZ3TU;I)?H#vw0ESRLw1Q1n&4p`a zaaSF{u#Z;0Ell;#^qt~1nBXqna{cIuLZ_*8*|yiUhgqN$eBjUu1&zhzMcUpyx2CVw zN>pQ9U#v)|*2V>^8mqjZq`U_xxsJ1#9@^DtH|p$AX*8g>xMV_<@o(PnrWpxO<7Mw0 z=8YHE8u1}UzIB2Ofp#t(##bP4P%|3_{~`9xu*%t+-|UfBR+M;ZOA*~;`q3sGI|k#L zBj4~bGo@N6aA*dB-<&|8#4H4W@C0}u7`iPVQ0T6%k+o`^R;&9RRu^7bWNVA%o_?1! zOOzLa;m^k+hEicXY#qg~jB27cVbCk&%)ZPEEAj@wnixD;P0uAxPvu`)7*Cv8!+f)l zGg6p=)`0BGW)11WXlT^dS2nEfGjqYgS{<@C)4ekuX8j23^;08+x`!w>X6q?0R;89a zI@gvSZ#&=fHBF@$`w_eg&^~dXb0?^!1^wGU-Dg-!9+O*B_}>%b1FRqNpIWqU--x*y z?2OgcuklgCi)REo{+_IcJw;p6=STft+ac&2X=uc|4&YL^w6OM@jjb)hr+KL0b&`k&iaVKIhN|MwsJ`yT>k{`GzP-yiz>-^bd?`sDxD-=Eoju>OBMTe zYc~G>A0O#&?SJ~THq|p;)=rogROnn3KQ25D4b}lYYnk;t!RWU5 zWqNsyo2h*+iAq&L+&C~iADKRpT=%Bxuj>M`0AGZH|BK@gD(#utXOrlHPcTU(9$dT` zoX#^sr}^gs2Z*{!IxmvjyW4?4_ zk`AaRos~l5Pz-`LRwN-Ck!8wa_JkrTz6h{vEA`h2dIb@DC<-_+7MNDO_#!$Xv8Qqy zu^O~z&$^?)M}(=FOkN291z=O8%U4~%4*TX9G>k^pLk5;5GDg3D)B~C~`4CVW@XPA` z=g+G4lJ~@Q|eq?If1rWI#h+x zrWPYxbS}CQ`Q_Br|C}4psN-^K?}LXA zhmtk~@4|@q>~)I#W9)^B3|o@xjtF+J9q@9U%Tf zM+Y0lLUxW|`&d*ZSrk2&mZ5J_9IjAd@k9ohK0e2f{OESjhqFP;}4m;EtyOhkS` zZe1D(tT-R(e&@rp5D-v$}0o zPt(Gys$mF8*&Y-xSwv>$`g9`?oIR(7=v(HX!cB~zm>bZckL=1tPizcR{;TT{%@}X({P3hcGb@e z^7TOl)-X8b)DBt~7PuL$_M!rlW)@7hg3-@+>grDW3nFZ6wSocyq!lK3&1u*uzc{5+Bcb zmSdyn%%C{50#NZ-WL5wy^$Ies$)q@rhq%az@JKSqE;>(vvK$=tz@UjS1daYUoT7K0Opf6n-!t9@ z{Lpyw%UE1*qFuf_z*}T(Za;ZV#~dBJX{n|-2!Q%@t1~xaGvh{hhp;Ju=P6hwp6Y42+1-UPbZuK zC50$jx9Rd_%Y?EZDNq7d_mVo$s7{EofdZ&I8O#Ker0tMS@6_*71pq+dA<&c@yr4*V zLWqhjY6|4~At)$%JBh&~8m|=Upxw|lbreFF0sY8azK|K@cl-80MOGtdI?`liD(r$P z(h-^ur{+#j=zCCkg#z73Q?Vy!3(?e}PE<6r5sDUWGmgwjTRMEJW-;KT&gqsGes9h1 z`J|39ch<7D1YtiL&yW78WfAj3O}v~%tIC5DbQ{N#W8W`5yDL0~Pz2KZPTE3q1p_1H z8_^RTqho@Gu0zv3?6eT$XZn*n%U5#^b@^F4js$xK+pu<|+sf5^4I=@5(e%(3OiG1{Mn`=siSAKc zr@S7|nIZlq+%8#lD5@6wSA;4F5b<079KA9hZdn=+@MU?Rxekw0Qd1oZGh1@ZR7Rem zHJR4}&Qw%&%{d=}jQ zXOPZQhJ=*{g@%IDo1UDM=lrMOWp&JlzPi1A4i~R&rtMf7oELXtLXvv&d(l zbWa@V*F;Scc3tEj-h^_o4r@2k_3G7?s7S=d7CRnWx^k1IGtD+Xruq|`DO?|`#x>A* z`{qqFWf6|6;^b|T_y{rY#EBDo?!Q>_{=MtEvM(NMKD{vDG#1TuJhRklzY8p{bYMs+ z3oW;5{c%_T;lj|v42$BjvYp7Dk}2#6v!)hLU@I>o!zpnB?8#tKNQ^!iEttM2m}E_P z_}l!q!z3+#K>E=-0d5bv^>aNN`@Pd?y?LWuXIobepS~!Yx|(rvr2fTEZcmbD!aqQqV9tRsS1GtjN2b4MMr1aRQU%Avgsis*KxgtDiPxk z@>bo<@OguV3`s{-Pl12-YutNk!wN7#k>7{xutAsJpSq9~KI(%i^#4v?aSE1M%pkVR zZ%yX>S1u8!`JddA-WEIjAvnR}T7yRt?$25nlM+{vbOSP+WxKHBil@N`f`lE zRzVvoRTC>C#VC|RbdPR4>{UaF0kr$koJSzTPLxYBIY7Mc5440P?kzZgxNjoLx}JbxbJwyI0NPP18h7ZYE{;rCPm z){uWu_SHp_d2`{3`6NI`L`={=vro+4O-Ta>`DdBc?7}hEr(zG!vCv13M;Fly{goKR zSXYnf^aE4ZmLe17cWnUiu^s-KJA9uV#$F!@mcEF`WnUx32_GjN1wfrQ559NPq1u|? zy4KvY;y($eEFxWIqNIT#oxgu@+SfSeM||Y0k)K@QA4DKbu{DZukg10UB_ks~9u?&) zBOMqoNqH*lA3ro;jN=Hm+IBQD>Es7_r>rdqCDFcrekIEIp4rHU_^=)MOK%O|*o;L0 z?FOLHulY3mywL<96{+6|+_FH6>9Jj1#%%9-qcat&2Geii;$CO5n&}RFuZU(cu%(jv zA^X_S@B`=PUu(se6y32O9yMeuTLg|ejqi(AxuPPOx8mpTUl-ItWCC=XI{Qs7en{<; zT|B3*yA+&ssX9^fBH<5u&W<>k(r+zbdoqmtYP4(wn@*rK&`-wXb2atRl-6%(h-4Rq zybbTM9$?p@v>zrnqUsa?Y^8G=J$cPymJgyj*ZlG;#`Hv#DXkA?fUHlGB@Q~i0YO1c z0Hs6-Mi_&e_Fm?H5NxJ+}9$=a_ zWNVf4mPK_k(CvDC*|B|l3dj=B3fNC7ecjEFZc3x&@nL!A1+!9E^Lw)jtEIVC=~70* zoT~PjYc^j32#v%E!n97kJ@f_aWzpbuiO1U_x(=bODeUDvL*JcEdN2<@T48B8N!Au} z9-71CdaW$FTc8)RCLH$|5d~nNCHKd%8n*H8rh~(S8o{#8y0(h2vl^!0!?R0`0fV>l z3%1d-Y69qufm&qJgK$to^5tt}wE~bqd3A2qE8s%i@r=-cJ*R>D4eGlw(>J8#8#l)j z&XqZXy?%am=vr_S3!zQhJ5fvJc{(*W?S*_7nuf;QTxP;B*6MwCygic_;!ATMdI6!D zGJvYblpbRU31+jSiM~m>xu#3SEX}dJV6?Ovg9oc&*m#EVO1IjI3H_(PI~{-Cr}j!W zmoaUt2k;&V^Ka={|Mm9Xif18sa>N>k0@A}G3P>5d1k;|8=(bs66Ft;@r%l^Q6v@1_ zr3cUX#;hwb)yHs?`0A{zfr-slM(vWGJekv_pdB+t^cpgJ1X}j!$BWABKIWgBR*r8^ z-)uc4iRC93IOn6u;nX$v5P*rF+P&kWZ7l2l$iLX3^Z-NON+JXY_zmIWva4tzC6P)| z`Ug8hl&)TjRu)yy?0cn4wNV&!+d^t6evnt^hC;53(Hw|e&w$fr0VfIziz+|lZ$%Iw zVMV$?Kp%Q-aV11!GiT}h;ngwdV0;n37f@?SG?Z+C8(56}x^EkI^Xv6`F|_*oypd^J zU4z44N3T6)=yJVk);F~JBQbH>M(i*~b4^#9`Q&|vyec!u?ig0w$mUl3SHy80SYros z@)uiz^kZ1l&VcPr4Tt9JLmv9Wxt{4@WTCUWffi1a0p1(i_hzgixwOBb9&Q^Ei zIUoScGaY5lh(2uY^vV8J7qSyRFc2d_5uvv@oq@a_=h-tlH%EW`o!Bkk7Uf<@{BHNB zZ`)IF_)_}FG$2|hak|rV`!QO>VDu-xq*Pbv_I98yk|9Ou?H->$k{h?9_Gm}?QxV+L zc;c(SlgdZN$<-4-pJ~;x<8BoE^6(_EGIJ-@uJG%$3pR%;KuF%_Cni42PXUBoBr=8F zX|f1|2rVrL?XnnlH8-kt7xwellHzxK1xZ|#!JeQStP~0Xj;Ha6`w{$=2tl}1J2gDh zfr3PAouxWEV2)Q{G*09=nH3OdWCDe^xIep7USbB4dl7O=)*)KhlJb{52%huOr}pZo z$7aTPRcpU>F<~#RfIag;3dk=u?N0#i_oSVjQ*xUmEdQP}wJ$K0YN=88R*huDN=i8D z|M#HV?h*OtWCScVbxcVA%glXYzAl3|vRrk=Ym!y^h8lmrB%`r=gw|3<;vCe!d|J5q|`_`?pbp}M%- zp4r)=OB>zQuk(lz^Or8bt_DP*&f>3AfH;$-Ct)E1=-J2C^OH}n1$jr>ae_u9`qCw} zdH*ikX9U*T5;YcGLNY3NJh+ljOJF2>rr+XJsU~qBzzVGmG`zqfC=X<5l42CpS$U3L zB9W^tWoWA%_4l=X#&omT_)*1*rJ;<26Nmsjc7WlC8^hT(vtO>SYnfV?fYOnaDM}le zhx_t=CUV0h`dZTIHmX)1)m|P%_Se`KyYr#OVQxrUa8GCkmO`!DiSR7t7%=E|z6`d$ zrf1)*qo`NJZci9EWXSB&NtllIe0Sjk9~5f#FP`RU$R_UN4J*G2HQx__gU$^-QtZD1 zUwMV-L-lPXwJ*Ya$7HnU3m^^~!p`66P0}$>u=dyYd|9!;(^69j2Q5mM?T9#d2rx+w z#e2e?+aiD%<)Z7i$7?D>c04&huK_+bF{>kQm-`C1qnB7KBdkJx@GGEe!<3~y7q(l1 zyFI)TWE&Xab|x-myq(Y|H(i<~3JSHOGdbX*z@k9iL0iz-(6FPdK;@oR29HPPAoCCi zIJhV7d;&zU%&HgnO3%FWR4F2B|)1ay-Af9|3{KNdNht>({44koL$8H z7%Te1npc}mqyXVef!8;31>RA(^HpT1w_U%5?xr)e3RT6-6)hIEh1eowrZbz9V_#VXPLuHNi>(=!{q}rszkR~z_$BAY&nkPu!xh76^ zEbZ!=-(5*Tzx(Lg%9<4{L+iX;e-c{`q<94<9z*-Wii;k1TwY?$7@#&N&!`Nhw}4v- z?Xf*1qziaZ<&+)Aj%flTC{vWmNs)b9Fk&o-+)Z$?;kdm4c1=j*1v7`KgmdA#4^LZq zbOr@%AezPwt8ZmyCK0+Fk&3aY%(`~u^XJdUboDFXW~ZSs1j`7r-&8buGQ$XRtbjrP z0xupO&&WlWuq1cf{YigjjL6&M!i$giO#|PLp@)MNzMMbQb;7 zHDwJd#ERFc4LAsOxaw&Xmx|gb$1xkwX-M4ZV4`-J(&DeEjUGzc7zWEl#{0y1p!qcf zTfE2m7^_@d_DI%22>HT+idu8j#sE0G_RE9mV2$Tf?%g^pA(G*l^yfLp#$9s;auPcP zAUL__NZJHmf+3#K`NV#CZ-9z2Ybr1u^-mHW3%2RIl0hE>wqK@SPe0h-w&2~n?Wni9 zQhp=yyNM%aG=#q-n$PMWgJ$Y;r@tJW^F6Bsw2M6IZ=4!ZyMHsXTR{>hYPFL0hv=QU zDh+@Vgye9PiyHBrEH1{OQrAdA0Ku=P5Q7RT3xVec9BLvzj_ee}BlTJn=Q?q@b9?I+ z^+8~iAq_G>xNasX?R|8QZ`>Y;ovd#tj?0)L28`KUFNrorY6G$CHmWVJM-Al2jE@v8 zAd55N@-q$&uv3w^$`9?#%f;L!nee{I-ac^osz1lvel{X|&2Va~bkSwb57ojaUkP~K zNJA20s5=AdE~3&JOBIM@qz-B_g8XZQw>&Xr@r&mb>vXIMw07G{F6BZipEdIVrC3^< z0Qr{TD>BAG|J)ljC7=EzhD~Sm=A$%BUSU~E>z&HLo_Ga^^;(E$fb2eW{P!c_k4j33 zf@~ahs+bLuxr9 zA#~AEOb&=ElbJ20c`-rgnO(RG;&A>9WNz0H{dfJYvj) z_Ld?k7iB7P#E~x16+$gi$u1ty9tp2ht#ZkG_3PNFQ?s*X;j#VgsmBVSvw%TqvAg2E z%!2CC#GZi240fQ?BB>RRe>k4Z%0VLpDT86|v;`b(bigiN_ zq8%WTdJ1mL4p>}~cg$|6e|bfo>Di*GJT&3?7_n|b7dYF=F*}FCQQmw-K6TkmB&-ql z)WiFU7?uyC({8n6>%N_P_9!d7IytT(`Eab;Udj4Iu6Wa;QQj5ExgT?e>ODhT5d&Wu z=5l=MZ<#wLb)?mujnzWGf9`e0qh&@DefE`ffvKV5n$_Xqf##MSTD9^3sOi;oqx{Wy zm+#GJBxg75%=CxSmfDI2K+X{lyP8ThqBn{SjxDPu@OgXjGy#YV^t~{rvgZmov7p=hA%H*IMvw zuHa0K#{`jYKZ*ja3#f*O?K|%V!tM$q!Ov^sq9un|sz+|lpifIRsK=;nbmq^ZMiir9 z0k5GrySYbv&SwTd9CY*DYNozUlYx+f!#DQ5bJXxiS>eZH9F-El+TII7() zw6dT#%OeeQFvj)Z-}JiiV$1iJ8>N(yTq|?}gO{=*h#Gb#V?qqO&%d<=C1U2KM^|-d zGjNU>#V>lZO%Qo}JL!y{b*|h{RC1zqCsOzPqD&v_m;L$Nkh|W(^ zV@OzDcJRsvnP| z9hf&crcpInS#X;leR{V%-)K zob;W_k2_go11%4c0#hasY!02Iy9ePvkY3CJFyl3AVoi^C*EV-$|4b(vXQJRwp!k^;>sEc1I;*Ur=&?sJ>Uaeu-U*@!;f=$F{n zOg`&c$-Tw5ZDxq0_}IL=d#pJxoRNg_x;6Om&UhR0)hwVbmJ8Cjab~=4ysDh{jVs@W zh~+LIHHAj~y#dF>%tGL4&5b%`!Aq#VZMyY$Hk4c~S6FhzpaOLm1mkPJke?!qZBFyM zo4@?8+5>SEhbt4miHrZ-uL~ZqV@i)ya6o}%!4rIrtwo$IR`6g}Jx_iFL8I^-PZ%pl z0k$&-9LQH@MmQh1GpyYD?U{0{TTSfuDJyeNHew`3PdVh#Qg~5NZu5klZjI2F^#jVU z7&_tio@hi6%-UG!R{1heQ&Lm)a%27?Mh4hna5~%FmYS{q%4dVTlG3R{&)=`vGmL#$ z(2j8_-$OVWe2sL#h_sNxI!URi^(Z4#`0dm@Hc_$|lL&#Ov=WStaepPbOSgaO(o%fe z$M4ZNJM&p#>LimeQ?XEZ^ZhiRhT}tjk&IC6rs8t|NB!61oyib@67Tg6qh{YN0BZ0W zlrv$VxI-APl4e7A7>i5G)Tpi__a5cI*~0Lu|LU~d^ok9RT*B=Ae1)CoG6Z^Zs8tdT ziFrwTyNro;`?ZOJdE%mDba_5&>t}y{Lbei8C1~X4VRw}sgc84WY5mqM#Z{)b;>k=B z0~`W2FykPb1UM9^!F%u8-nwJQrW7*)-3DZJm^+S<3fK)IgG!G1O3n7|&#u2t=JCcJ zSc;)8KWZ3F7ZiHb?0TKSM`N0O;}W!}Kl*i??y)6WHTzp#d$e;&;H-_c45EQx|4%=mf_ zY985eK#kD5Y1sqD_0*0|VQ^daXwVpu}Y`>$z|)>)r0(pBWM@ zl6iy>mZ2|zl$`F==q!wf}%Uw ztGNAY#JyuoOng!XR41B$tv`6kkl!oTe$MROL66^tR_$EVg|e2ACg6aFj3PSJbCwnjdZ7A4mw{3gcGF6iGU`jTlCvMR6VQHmwK`9 z3eWLLPfce}dtSUP{PRPuLcsQOet#mzjUVq*XF}(5DWBj@#gR*BAAa5_od6%b2zmL` z8ra*^S+L-xp2P3-543FkP+j6k6AD!=5+1U_>6xCIure`Mj$wh9ZFMGZJv)3}kE~l- zoO#y;}usi4d;qGknF`D-FOP6}X0{8ASN3D=xr3qqpI{FsN^+0~i zf9z@3zCACyB^;0!VjxnG%q58x4~U#fw}zQ^e7l_8Mw!drnfhC_yDCHK&S>9G7FcrlTLQ)sqRy@8V+dvf>(NM36wVtK zj(z+l$~Ki<#KBu}m5X*zJ%W)w3R;@E%6YgEgOA}Y%7?NzibSKfr$ft@TLCvhNd+o) zt)Pr(w*GnHx!`vP|2dyy_!8ms`!-*l^^4#}tvg5ObYU1XVQ1!D=_?TqZ0)tc!Cy-Q z#;-kurn?P;pi=OXnH}L!2o6?fZKkl6z6tFj1H>jp!FMP9{`KQ`&1bcv6DXkHKf81; zTUn!2TCI``vgU($M6}-N`t{CzGoFgdAg8W+fy0Ll^GP0mjT^d%v#7bbS@rkXE$Uae zENAT#yh=bm&fq(B%-7`{+mV-ZMzf66a?sFa2udn3{HnMXd@ObqouC-@aPxQaQ8WT0 z^Xp`gOLg`u6@FE~g3VTsy~x!?e{6bHn;tx-Z)`5v(~(xd&j@ z{i_vsGajE;SdPx2mzTj{Xh@S1MrzWy9Rj0x-KB|nXHs1DGH$3qp1=@=?N{%j^kN1t zLA{f`X5HEI3_27JX3G~RMp|;zZRg5||K=}W`rN5LhUU`2nxR)q|^RQX*@l+c#+nLvO6vW+C=8CDMQlk=f^%-`|iMo z>Q&D(P6of9oa)}~m;H!MUmEfXQ$Y>oF-K2(H2DgK&!%0bkJJbaH{R}w6|slKdmQZ2 zAR}+a6t4A8?97NjxQ0g``|cB%JiqBY2=$&1zumvBd$%pz33}%@oCw9D1G6@Z3-wGZ zyEIs1GJX0S%Ot_@sMu5wFpZ@`Q{Gs58OjP*e-n@CqZbStd|NZ`?6ReLQ4#x#iEM_l z=7q}Ln3t4h@{m6$%ef>Ni^z)-}#ab4uxYwpD} z*u-yDV?WgilSplIe4 z?nFIQVq@)Azhe1EuYxnT*BZb4emgtc_I>!|^~^cy&E3~=iHqjzOmo2(~#XRlYgXu=T8=*!z8VHG>~o%-%hjRi8|Wa`>- zyl=3hmEErD=I^Am5Kk`{R#D7h?SS$3f{)RWggN#6w-j(bcKd|;fi^bTJw@c2tN}O< z)2edg2m+5i%w)?NN1O_09nR1#wrALG0dY|wwpsMh(eSXBO5niN6lXL-jNLj_4=kl> z+ncrc&A_p(DVoL5i69VMc5fXVJ7ALdRZ+SBHC7PxjH@8G)A7e(=P<`RAqi$#k`c@T zw#sYIj;M4wSNh%hNVT2P=*;yudN|+D&({Q&NeA({q@F92lG22WLXNP}zu}>_MBbPc zkJ&^cQ^zzB1I34P{A#~_)@2!P(=;dlZXf=$&z;Zq2j^Y0e7X9nW=>w7L4MJ{FF@kp zi!ALtkJcXO^Qan!CFvO`Hzu6zawdFDZ1|e-qw}8#=FTo{HJ2gw#}0Hhvde#w?E)NI zt!}I_BlvV>c(+U|hbM*AdimpzZZNO5dh|8(a}ML#z}4+G+en z)uA%WIqmM?Rbl$0-m9)sauP_X^9Rsu!gyMb8)iiMsxg z%6ZV?=cd5`CS4d}qMMYr=N_TlN$O_6K-ws^gdF-6xPjg>Oze2EE^O5lG(DV2^QS|u zT+Hd<J+cfue6-Hf3W)Aw`~rGk{Pxc3TiW2TeF5xaS`eO}R!{WBcOV&_a@{6Qh< z6JrK0i89%Qx<>X9J&kw1V9}5?V7ja#B5UswykKNSI+MtAkJrhlJy}fq-~?$Rv1HB2 zX3k;P!(Z1Cwoy~_Sd_^{q9FDQ^?^*RrzXu@kUaW9Et@8ko_%nR5?wmeOp#VMd~gC0 zgHKS0Gc0;$^e=QxrjKvHygI$EeBO0^dA6HjrHk;Lm#U1;UO(lbdvWt8gBF8-MHKxp z0=K%tpR z3{pB_-0yV1)ep#8+|w%lYDzy-1GUB7JMSI*@OZrajBkxTUXVpR2>^(_GmE#KN?b7N z@jz>xfUXGCc#El^PUr6VT$a(T2M@*oIGr@@h5gul{DIC#R?K>E|Nd5(Spi4b=Hz(r zQM^>=5P;)&h~>ylZT~a9vi-=Bd0Pul#Mf1Wq!~AAltQZILF<;G{+OsvM4I zU+cB{%LZy&pk0r4d^V9iNaq2ZevQfI|xo7<6L$=bqg%4bRpjtO$jMz$anwlcT=@NXEW8 za~`fm+#pD&+c^>=c6_`w?Fej->z({X&dn2jJ5JK-)2AtFy3ribnGqKK!&Fkl@x`g4 zE;Q6Kmj!YEF%RISq|ycjY}&Leob)5|B+gLyfP#a2_bSes&F#Flu3N7`eI~1s=ClG% zTsgDvSjpf)VUfE6F1RW!4zNYDAbxVP7#A67Pt%hPwOTcTDLfIIk#;&EigI0)xIF!s z#aXSwj@yq8*nBBuYRBZ$*~?4s45XmE^m)|bW@$$?569fm>O_k(V9D9c_R( z_2eM)@f(y3uZxOswJ1-oURC`uJbDOlz&0YlM~AUyBTZsExju} z|Nmj{&EtA(yY|tmiG&8qSW3}^!d+CNK}4dWNuvrWqCrT5M9NSJNobNZ(O`}yDXC~s zX(UBLqe}DeJ68Ao4Da6ikNw$y?LXc=@B2LW{d^j}*Lj`iI@dbZaUAP=25Np5s`(wk z*(K8oEUwmrjlF*zvhcz>HKFDtXNjG$rX@w^CNuZGgFRQ4TrDKRINbb#bW*SX4Ca{| zbvTg3mISnd_@-xa9Mo)1pD4ME?T2CN!>A@yjaG}R3~e+8I!P-CFmt9FFKuZW$p{XT z*6dVUTXaIvbx8|3bPR1u+~O;7Kg{1VuK(@TK4e+}?oaur;Neh4Hv^2+Y6JQI_wb*C zzH{n0)Vz+M$)WP=GYF+p2I_C8XAk|HmMDcd5Swo6eZOfl=(TYVpBzQckt$4Pw?GV# zzH`r>JyR(4hp!>oJj|UNs_f?;B0Pb=td*In-{R?t5(lZk6ufX(+7>)yv5CtVGiz~JB@ zF>}C=)QRZ53Tv`G(6v(ji+X?{4ERY|fqX&78W`VD;ThqcxjPL6@_Jf&8uK#nB73fu zG;!{2neTSc4WA=`wG92bd=%urfux-uibF;XW{Vs~p!lN>Z&@caBrtn6o9Z7k{~q_0 zpPtu`EEB%2Yo4Qrk{31}__p}ePE-S2_24L>xF;vn1#!^`Sae#k9mt!S6hjN9s$4|- z2e=>cd93pT5JxCJpttyKc-^jfEgiSNy_#y^G8jMXU|0a>5G;uB*!}H5G*HkX88KL+ zH4PacWOg(Y4^ePq%PLeE?|EM?O=B`5JU9W$p?^DzAZ-Nzc@Y zj4bW8v2zzj=42k_qF`cd9KRwoE;cq{xhYzqiS2eFnNx!surtW72vv;~M;ENoHf$Bk zZy-Y}JOUvKXnN|e(SEN}yVl)JHV(lps8i^sWo^aPN7aLXky%E&)`>J48e|!{t_So+ z4j<5AJ-Nx2x{|-YS9~o>8NCII)yEgE=;D|IJO>_$mxNU;tn4{N+Jw{=x@C67yP%&A zgPiR@pDjTK)6$PFXwcj&>~giedQF8kHn zs62E2ENcK&SsGAmVj<(Ya*saMXpwOux-vBlK^TjmapJ`D500e-k>Y>J^8j2jjRdtP z`?e%lH6p_$LT7BkipLOMCw15N;@>8lst?ymAV=_hv1;geGAewo>S=rB54_)f9K2F< zUkCpz4&Sv4mCJ3Hq!_=eOnNQu@W$T?`c~?NN3TSmyKeeQx0a=DU1M!Q=?U&k}G}=U8y6U;n zB^3T>C@|lWjJ}0Rstk!`JC`^~szi>)u1WbgjDjDu_8fH1(B@1@?L?_Wp0#iTyKL+0 zspbs#74&?mx(RuZ%dVzsX4YW4El3GJ225__-o1A}DC+@TER~pc?d8jtGz9}#lF@Fv zE5X6PH@*9F^C|)d?w*hZVjuwMBs5!Ab(@hb8Z=xdPrgv7zKcA7?j}rz4rzV2;gUYE zQ<2~;-N==t-Z_N+Jp<;z(V4JigcUw11aY^Y#+@fQM#FTmWD3VS1sQ*qYrrie1($(@ zmGG=C&SS>`vXX(uiUh8ULy(D>BS(b`Kp?O{WU)`O*(_OcZBo$3909$0Rq>ZNvR_ey z%u{r}tMFJ4M6*>Geu1KH4^rJor9yN}eE`6rc8}HHIttvQ8JN`qBT3wJ^HDWd{?OTQ z>LDBiew+r`nXKNVO}lId(8{1KhG{Q3fbsHMK)g7Mq6)32-&@al6>t6AJc8_wtC&NL zH`MNAq%b%MXmr7TdUH#ejWcRD;q{z8PIoQPMoL;h5wH zpSzteJrHvVCISk;HFNhju_rB~p4m`F;9h?qcf}N{x?o%peOWS%*&>==Y!4%z0kZ*Y z_!=zN%dP#y+Xnp&?RWm~=E>L}D71)KEj79Z#?{cv5iM(E`K~>@o-*xCTP(w=>x0iE zU5$gfc^`)kT8`8iM$ozgHVy#T0;G@RWeE67>(F-N%g8kUX^g02XgUN<>L~2M`;0ia z^qMrO*YW?l_q69De}R1!3&@IjI3K!{S~rQ2iVfwT)cdr4I+_3?8Ikb;>{rg|)7dWn z`Fh(^eK>>IJfIPD_w*D!c~S{_qd)kIG+UMUgCK*-SY};<_muy}!!7*a3nTqtdn=uT zi<<1io*(a81b?%UO#{vq8fxEzZj;s=e6tvUSs(+GD1_vX$eD6sj2M%fF#4A%1{;&F z82PgS{`JA^Vc_=y(tavaaW1qGIEY?|xIjS-ngJpeVx7sk8~;mio33;s$h})T%BQ$-Xyi`&|!prjE7&=LXhnA!R<$d zyb5M;X6{esoP!pgCT-tPbP)%Ceg@QW4X`H1C)|0m{1ZsbggSx?kqe3oB0YJ=AvYm< z%t}2oT8ieu_l*dVfFRTD^i?29gYIyBv>-rLA0w=a+<zeh3ArBi4CU-Ax7ZzSrV+vbkEyOj z3F(K)cpew-vcBb#Lk%|*XAf?EF^6-8VkYj`yvExbs;A$F7tH3;CNQFC zx(k?mH(f1LJelG)HULv#2wu0lPCH$bnZeio78D3$CP zhrr9|{s9buEdaF)y{4&Zygf7O7A#!&H7?RC7omyNY@})bhpz~t$4t@>^uTm7rW*s2 zM*HiJz*{?P9^(nJCkV&+Lm4iZBXAb-Zcw|uDL)3jOgJFc5{=#S6z>3?!=Th#gp+P~ znp2MAoA^%c@&?HF;U>QjT8KPyA%LVFQaMpf4@3B(A~z9`rX;QlBQ)E z|3iXH%hj!zGY;e;AA6T?)~q|n=IwXeEiZxWoA5MnNVx=dD8fd%^@T+V#+QCwT|#U> z&^Hi9$U_gilokluV@}W)K?>Ejiv;tC?d^;2xtOPW35U!ISz_}5tm2Il_!{DnMG!=a zqVzS4Tm^^aQqEoBSnns+o&B%{t4;K~Ym$1bI7imCt8H~N5UOBxK_qNC*3bmX z-myAuh9;QPV7C=@>5$9#!PX2?iG=c%-y=;>UeJ@H+F6XY3_^1O+AUP+r)is^?FwZQ zKVl2Haqh9T^@HBy6*wKBjp-}iyr@LonhL;&xX9UEzP}?tgG6W8=Hk_S+@!ZfQf`Cr zk9C!?gFL+V_*VX#Iqh#4qiZ`TbS3=a?QUH{DC7>+?MRsh2HnRILr><-Bp zJwBiptGNFl;ux{KkO|;6(F8a~m0DamkLzG_#cSgNCYV~0-5;o3^ca(U4y)yTEIk&T2a>8`-LQ$K1(jrGDdO?d5AFN zGkm6Nft)Wwtny1q>5@od!j#=sgLUYG-O`LuK*>D99SoalJ>GAxghYfSL! z=?3XKs5aD-Eb*dn!C1rk0lvVG-q?*^5zrrhTL?WaAfD%Lz2bzMf(A!h8n2Lux;=;H z^X#3ulnkngb4742h;-!Wgq;d9yVsQ=m7bko=95hh%x5T@1qL?>Vbs)+WGH=#XXS2i zp6guerU-Um3=J5F#4H^3BFz_gB(JbbdgN1r?vc<1H~Ai9nglVR2&M5y-+g7B;M)bWm}CHyZ~piZ|5R{kYiB1n z_7z=g@FyWB1tJ#G&O@RJAwV92CD1%>U}c!PZE-tE@+G*7(O~t_tFW`}_N1aYEHK0#O4D_ovK0qm=^}5AN22q6G z0Wh^-pufODC?g?`G!kip5zRt5R|1sSF&I+-E5GwsZLV_yQ2L(wJ7{!$*6b`$3VUQMJbuOG7>&u7m zR-kL`S+oJ&%w0pDkjRDH(1h~v0hIe_VE7-F`?|fvNq;>-_B^ot2B+dRs5#>m_G?0AJV3V(ap@9tfgoVuq zRiYdTs;ubcP#QZ%Gdc+#!k$Ni5vXG6+~5+_nl@mzQ%(rBXWpWpwZ&tAdU>FHL$8AvyXr@ll))qqxal*X%TK7jQ?X6)d{ zqAg)K|ET=x)w!fZfrbov!98djfLVd7X%pfH*g-FqE3V%ft$aXz0I)VGfI_pu0{|Jm z5#X`r*3~UQQcmWMivu6)pas%mIMqY6keNR=KI@x{@fKG@|g%M_|$ z--jFb3oxZz{3ZQ<@9PyyR)*d}E=~Audx$hA>Fdj2icKEu`#d`^;fMx45=j_xFWOMp zSLNBKPQ`wS!*M|5xPa~@RaNvsF(K(UO8Lawt#U|CK-`HLx)OlHi8F|jZZTN*n7Bd> zFi?!*AHLoQwt^#)fc8%>=l<=kqmJMH8+B^DvL);t$bJLcb(dgp0-D(XiJR8;%4}aE z!8bnrFU8q>`+9K$6Xgy^eX#kYXZbTy+u(8{UjHia`*U}2m(R7s0z($0ne{zeRP;k0 z-WNGR%@9=J#g?l)5J|F7oO3v;ZaZDNZ#0%(D!%`B{gLcwoOEH=`#R=tt_x?QoETbE z)odyceF}*F-yhBRv%N2txFiS|&_CVH-BENZ7ZLe4m901!k*+U7j(mbn1Ctuh>hzdO zUoCVF&;D3aB2#L|_bqcu%j+9AZ`|+*&g?~Z6$GdYXrg<%meO2HbR)t0p;d%cb^>=@L>c_Q%sd!3w^96gp8fk#&)?0CExtnYoE-w$h z%quQjfUnWIM_*UBW^3UygltNzeZyynK3Mf-i-v}N4y7YV`bnru?fwi7U}N5)p+;5` zLH_!-v)uq@uRAt+_kitOaJH-I5WkVL82|2`lnnUnXBw>;LOt zH8k$ZPZ0z~^*>)*4$e@R1=`2Izk)x`b6f(b?mu6Duk(NX8WWqG;Q#yv{L|w%-eUj$ z1@sSg0W8=m|Mm4@!Q^uM=a>BN2l?L>`QJVAzo+JZ?!a5~nzvlII-~;DC{N{)R9Z8kZK~$G8 zri}1VkhK zK-Tbnjl}l3o*RH;76APzN6sRf7+2YP1z0F?Gc zLxCz0fYipkGrsBwcvsiLlEAnHxbfBm`z>jz6Cm;4!n6=Oa95g-Pfv!he>988e+o$QtLIK%`F#QyaHaIHo5+>$?~yv? zgM;Xy5HtdH6T<*2^bz6YI62d;7Jiu^{*W=+k6#a0%qlNX0zM7E=u1lrho|RQkdh(H z1@ZgHA_7|eXz!hrE%r4wL(c&`o*-;EP(D*n<$HTlHfdKJIW%726Wh$clb$c@Ffzq* z+Yk^vkTugefp&{L8}SJMrrdWPK`96;1LCc~3WLlse9#!lA)}-NrFgI?KxZ`2ELe0C zX2ehfLJ(eYEop45tmx1?3yI2&3>hc_;{uGw0nP)VFeuE)(Z_q@eQjucgS)Vc3g{cZ zf&9im)Jxnn%pq6_-$9Hb+S!uj7OPn`cDZ!=AlykAvWejTE>ii}HfdA=Ai*Z^eF{vj`7`^3oUDRCw(rKb*B;WYmu#r%))zUm=9- zOwf;UTTXL#a6*~43k(gfM(-yC0}euqOFNwHF$Rzk!480f&d}!x8*<8G2wMqhvweRv zq_y>d%&C2i11{eOfR7QzP1GzL4mNZ)ZiIzJldUvp{N#_&Z^EudKrY76JCJ+^8v|Gr zkfw|}&u_HXK<=Qbht>ha@j$HqviEE;T6j?Snvh#F$_=7WprRm?9QYh29Lt6pfI5^k zq8StqYC>QrhX6u#QSa?U^dp?!#seOedcXeWF0aT;H$}%U4A{Oj4gh;v>i$SgE2##` zpE{VLZ%S%oFeV|*3#k)11wj^yM-B|9=a<^a4)6o!L8PG#KoE}qr1Qx!kwg~Qwz}}* zM^l*~cNA;%rii5pY+~zAW8l~T?dRNyi2+COcgA61Dku`DOU01)CZHJk&nY;7aC5mf z*mZl~*V|nob=Q8Xe9zPjCqN{}n`LsQwp8p0m^{q2;^JQ)LM~tSh8;i@gl;S4%Y_;0W4K ztY=lM#Iq5|xR#-prmJm)GSe3U@xmz(mtNMnhlYdF!Ki{GKNpydC_@nLV){&YqpS?V zgeE>$WYHK{hY5~HP#;U9&0qo^1fCv`3Cf61?gmJP=BPyCS3P_>$NhF=WIyh%-{TVE;)xBLBWy?h zw8ix+SekeVt9WIcXh$Dy8q_iB(4*y#5snOn!H$+EmC(g@d_z|-~I zBO5c_C4f5lx8b&uB#slh1Om7x&H|af*vMuu0%IoQhG*!zsRKylK;whOn<%KDwrkpD zqXQ6)r)l=W>zW0<`wHw|PV`fAfjRYhl(bX=Uwc(pxHTiYHoKQD?;68fwnzFs!Ye@Y zW>3(Urthcln)@Qillz3b`*cJSNao`~QILjRu74Zif6BEvaQ;z4Xk6pEJ zqVZahQ?1ERJL;l+#o~=Pe-@|A&Lb1|&b~c>5guP}^p1gd6nQT-!3m=I1zLO208Ox$ z&~?@tDVqMg$H$?6?iqCX)v$f=Hmt3VBlW_J%bU;N<9=tyim-U^wEAPi8VC*tKl)J| zOfHz6zJ!0a5nzqB7h<)B$n&q9^P035uMvNNI3TaV-$5z{=8rHSlKV zt+{S1(NEVww9w~V0vv1_bSC@^C#>vp(TWu4&%|8$DCntvw6#q~I%f2|8u0XNoJ4je z7dt#Y)v|xC1iEHcU?Fl1XhfWHuE1P97H_%<9we`OnZX)}D||M?iRBQ$3Vx9HNc8?n zPs6A~1yy-?7cWaeX?P>xz7>FHC0Cs4Zx0(xM^>W$^ERBLW-*-Tf#*U$me+z-cP+}Cd=$uYiP{n@O4I`W zo;<*id0YYDq)N7%0?A_WCL5;MR`N6QMS3L7u?u80P+^ucyl^=cxDLUQ9wUMdp6Oy< zDNdj6{ti=W#pi+qc`ZSoNGyPt0XFA$Zb5muT@32M2Ofzfi0q{zbBjr3x2t$jDevv8+AY zCJrYRMw0(~3(6zWIm0}H3Q?(@+c*_ib!$lry9-&a52HFy-u_zd3q)dEDDVMGRPM+A zWC!?2F&A;k_qp=O#R{|zo@gCMgCK8CxbQafoK^UF@=vOHZ3j@Y`9&$8<88QJFsjHb zp42&TiGCSO?{z9lLnvMzm~j^`R3hxjb~>-n%f_UeQzIDdsN?tcgZ)3xF`1R8|-H}RzB61<&5mijEiCoMCg+~^9cC- z>;VILli&;^JM3`&>wbQI6j~9$W@EAMtTyin-Upp@&vPvlXjdG$HAROH0V*eMYitWO zSCEf_sHl*lPSOlr#cFFP9DI1!TgB(`(=L z1blA9jLTtRQr~@VS9aVAh5CT;hQ^}k5B@{T#R@qAE{^#1-*L#wFa;3=BqH3=1OWJ3 zu!m`%Lp@7g=_auTrl2n{-e??qedy)i?4`413uTW)Iwhk_^k5HE#`D3(3?)D?h_KN7 zIgJ1gzi4;4wV?fBYsGEk5Zw5U$fAf*hpixi+22eKdSN_ZRD{$ced1t5)o=aBGO_-JEUUm4Utp#QZI7H(;p8JnVU#m_AusNge^KMjf3@ufe9NC=6H` zYza2pmDCI%4{!7bqMR){G7gyDEA8CoS$YInH5hHzuA}f<$RMdYkI*Oe;5hDvD6^7r zWX>X23sD#?y&mjxr7;P|9tOc)`ori;^p>w^XXhzid(UBLIw6wgbgrJxQV5oNkR7l8DXAxh%`!= zO7mrxT~jcIFbE5f_~~9MdD}KLn6)Sk*^Z*M_>X%I?lN`;=?kp}A@jx56!2pvSUk7t2xF)AB`b$()eTFl+;KDG5iXPlbVCji~LSq+V=m6yd z+*Q>+_;jcDj-X*#jz-+?mx(*HwJ!}$fZ4<_=;jlBvSUa;xXTn7>Xs!}%qnhPRgOde z<;B+3Pd|dBipf`&QLzQWV0Gxw?UX*R`*`(h4t|`0#XbYEHF}Ttdxms%OIcjqQjtP4 zQP(EChFa!Qw+m!~0&K7{xP22#7g|jFWt+(Z9^6A3PXo;nR~u1_aC_Lmf&e}w>u_u6 zfu};&9Ec+N?&jXj1An2j2c<`$t2?zBc0Rz3MQu+6Q$ASzgN8}AzG$Ek(*`L3i3VW1 zBxhNGTqp8V=x~H3YLJj(3{PZX5D5!iJxd1chB9BkahNnh&X#Jt%-#@B8*A%48+5lq zsiQG`sM5fnZTEUV?+F2E`AEd{AmxFI&B%Oo2LxA3|HbM4t2o<^DG8?!0Va#ta}i2aTg0QQ8-l8;3iFCYAHH` z3CXp#z#=e8JyQppc^j zM8QIAQ)Z}ocqrpUd!rs%i1<4>hZtsunjL||E#w(S2X!twmRjHVqsSmE=gns zkCaPOl09@ZG%Vq@rlnyi@0O#3NK$td)6a~!>e(iijV4ZaW^@&=Y4tbByDV!zRW$cB z&e*qf8QYa|DQ#8zYi^8wSn?@n;)6_fHT4iL%Xc{@pF(jeUDj9R$IOI*Bd4~fOs!vY zaQ3wXG0T+0gNt0h$tKsj54YUyI364IZEg0*@USRarYK?+AhZB^heRH$Y9u**ibkYV zHXkmXBVaYD9fsoc-c!nLgRyw`;ws4zGL}qO`90hQp!kC8*YRUV~@#VQ5+~g)m z+!}0SusBx1jRPKrQvlcaf>2tcWGeH_TZ{qyY|_FSjz~0~3A7$+uj3MfgTyj;Z0UDO z@F>*P)Wkn{P6(@L2Z)oe>+4MqT{?#1!J{Ju%ev0tb8q(M9(Q@0H!}h{LlS?wqSh_W zLub#k3UiOAn}51rQupiJLukGb{AHQFGmlG39Pte`-Ity*u>BZZC#{1bZil>gM&(!X zu=1pU99H12LH>R1M8ukuRo3NMl=_ z0m%@hKRg9LmwI`j1_Qy7VL%WLZhPAEM^R-UgRC7?JKdyHGG+rnM11rJU#rq{tFdMM z&QA>udqC!YtOKi$A}i;X*CyhaP7f4dNb90UBL&+Bs=Ciolp`mAHb_?|iK(>Wgj7Bq ze&6xAMLeZzv_RSYkqdMUSVYW_@MG0XU`wzerA^H?xc6(fTU7wQ@pZ+!XS*z~ zqrL+ijHhR)7V$S{u@A zrXud%Ki232G{O%6ixhhfNd?IH6O1C75rYhhzdv}w9q${ern#Yzmh9>N@`~u3b-jfW z=T*S6)r=bfJ@Yl32*hNz?bz`K$v#G0k@%C$aI7MM(W%10oySn}yyaY02RQd=__1YzK})8WKKMl?hOE_JCInDlpBT(_)z%se&}}#Kib| zvYI(mljWqo239B=VBW)xW6lORGQENk*+WXkU$E2h5%jP%`;$G zz+0X00c3od{ekZ%o_lUalrxaNT#yT+%VJ6}M;_@VD?@(+y2m-y=tvN?&V;99`9z6j z=1GHnM|4$3Y*VCBrfvjL3(rHonxoyXd73ls=Cq#Th80m1B&hI?Kdat$jDe*LxvhKO zw|MOAHZ!gcc@{Xjx(T_h_>k5c-g_yhqmm-6h3fc1e!2qnt|RrODRan|IPiLtQ+L2) z27D*b49|*ik{|>D|M_Qd-EV8Ia}Qg$&Y`%k)~+`qa-mz;IJj&l#)hA~8SzB1W$+qk z0MQE}*g(#sZ+Gnm__o#nwQyIfEsni{5=->|0mH=BO8E2!7$YNs6SQo3S9MAxqY6>S z{^V9H>v`12zzJewV_~W*j=1b^*N%ZBC|(jCH{y=v0OJw(CA5KL?eRB^gt3eKfP`cF zM+tulitN1&irnm!p8(DN!6;!nVu-GqI^;QLk$4fR0iY9dH4vb7DPdgo3u>nQ4IL5Y z2XJH-aBEUrTLC@)8x*x`B%{BKj{v2ngWWa(1^JG3eSYA43N_9v@M9SfphR5IjPs$)lhBXXRU@i+#r03Xdkkc%H^b9{W@xMm zD(tdE+y(K;$TtI`fqOk!Z&fNXvSbr!tuu5$fud@jJ~a_Ko83Cv+s{Tti9AT|LZE1F zjVS>SR4Sr4H&>l7IPly686hdE{9c<|)CKJzUxg_|n6Ij;49Gy0;RILFzKA;Qq85#LJ;2&@(0H`SzMx5aVJD1+ss!zDBXR$9H=yA6A zv#PeS*PyVY2HHH{Hq&oivrt+Sa7-##C%400!GI}Si1s!tj`=3sAbP8#QGzY*jhkHr z6$trEA}>b>Jm%Ch(2fD^XOJ6_{W+=2aiYoHk)~Oq_a)FOxoaN=2~w98m)>a*?8v49 z1rMx0{-QD5{ah$tmUXv1L01j=1QFX%TGKZHE}y61D1yh;g}NfvlQ{Pw~&piumUSht}m4|6xY*n!#37c$L#TiVnE1=e?)_BcZ9IBZ4!FfixK)O{Zl1#rhzdF} zHd4@`ZK;bZn5@}gM}l2)1|}xxezU+C57UFq+Ls#X);)E&h-}OnBWL4xKN2Uo1|Rt5(+9+JYb`1Kkx*& z3KCPgUWf*dA}tRRjdHY9As_L7{Up&EWB4%djp+C^pAuyYnJzV8!WKdSIkyp|7<*-h zl?uWkk?ZkpCMgL}*V8Z*q>$uaffy6_O309uNSOE+^|u|DDtNv~+T$vx8`j&VC?O{z z1*_g2jcN3tU{edY&+{s2C2lqjf-fws?|2OL(7X}I>1gT$;2s+!b<`qM&<=q*5Vv!nHR&!AWYdi_CI{+coU2?=cqzppBRga=#yEu0_pbW;_h^c8 zDh}VD%~!5@5kr+8JXnF$3Vq=GCZi38co8J_rdhL8sa>}5;>JCk@6E#YrL!{-1P`!p z6H<|$5knWZd@iEz?wUfD?myp;4|n~9Q9g1GH25V-ew*?xQt`Tko^=#3D@d#gjk>QB z9+}1M(X|+~6pn){{QWnu4BeiAmM6!YLraFQ;coLK01jCSu$krU-*`Rc{|8|pLzRYk z9U?OXG#zHF`G;#+K6m?=@o5c>Ye~pR$VU{tL)=I#ghE;4*&+U0<9**s@_g15&^c(1 z7PK_vaSo_R;Mf)aEH11dy6m#uU-$3b8-<>uapp;R>}ta=1%#T@m-H5*cG^kd7@`4o z8tOuH1E&H%XuyCfsJ$3(w5(YWoNN}yUD@J-XG$+=*{Ts4gc1yt9mASJk-vYD=m$<> zZ<_qV)JyoTP$;2sVE`yG;?*01dn(oO8dXJu`H!vsFM8<)sq-S7>Q)jyw2J@!%G&WE zcLzgJ{oh}nng1V1r~U8QX)?hoh|b7n4>vM2w@DAfr#}FmTVV{Ni zkysTBgB*tJ^r71#&&ZF}rqHv(nEehATNB|klS{uP9N+m0TMZw`H~iYfjI7ql^)HX> zlWOeE3iFPEu>;1<&;;!3w%AlCEQkRGw@7OCGdWPRkP2lK4GrXxr=O3Zlp$(7`SU|r z!52C*G7n)dLyUtsPMyFd|39>R*a6sP!UN%At4q90)0*7q%30J2Sp`t-(`~E1rH#@M)txw*@V)AF_XMMFh*WxU_<(= zdmrgq-*s!K>bh2})~n4TZD)RAG9-ZnuQp)n6SXCqb`M5qTK@HYDLvrLrgahfjb-AnW*S4m_3IXOcVxDD_FA0) zx{q|0&AUcR4nCb2T%du-1rYasY9{J1@~^=rSqu$^*&<$Ek>uh+q`-_LMp^oHu;gI4 zQcTBx9dAU&&5hbTn(4>@X%qN}5buQsN}f2OKrsPTBJ-9DO|F|RFk(x0rpa13IbXhg zJByP*UoL!3HOt)V zS?NJxgZiYHWcP#cYL-S@|5*B%pxjp3;CnYB<3(1f*dc%AgL-9f^SvYEOF)^KzHVCI zOe?QDqgPa74BqGK&&zfN_HldN#lNGIT+jQN3HH|3tZKX{YtDmz;I!clVxj2b#nGu( zG45BprQOhz6HY@XOy`Oc-&|Kv5RIo!t~46+PDT_(9iKjZn%+IxrTHhsXC21*M8854kQ7ewQ*;*!;MM6-=U{>OXVL|lBunfe$Lb`DGLvEYx;QKyD2(7uv@#Q*z zaK;Ld;{ zihJkvt5<4m^rYfnY5=r=j5iCV61h&G&&Y{D1vEMe8t_5z&IC5dVY5*wf^;nxDJ*<~ zc*r;hr3^CRvv}|iibTr7wQPmpL~sbga9O&GEp-!M%W&yKKhpnmIB#%7oQkr7MD*zL zAvSMB!K3nX>|c2co^tpmI`Z(jB45!NzxeFnwv~&^9z`@ZD+b3~N5{6@nx=q#ms?dO zZ#jY!M~Ry8(?emK>mdb&70d!OZa)qRB01fRiG;l&GXJ%B6ARRM(N3LzXi1R-)RpL# zsfYLW75}_QF8=@dvK0D%CA+y(y5sFNwddo!55KOxd`wYL{T{;q-@m1nGi>!;b8gLF zO-aiiO=6T<$PLawJ&9Dn2+F7`KnRpPLj&qx>;+{%ZBL6kS!#c?Hh;~%bVGcX$ijy# zxhd|v)h?^7Z)&S0S~^-u1WcaV?(Eg>W~KRC=4A%&JBA0m!+H!+AHPFt6up8N;8HND zSPhs#;OVKJZ|6Vs*G?<~rQ+~QBmgwG3oW$M_%H44}$|WIHap& z=bd*MQ0rITY$opd*yneOd9YlUZ0v>U!6G6~onHLE51VRA#dXjM=?6=@X#aA_?ask{g4fxBbYc zwZ7{G&&R8Nt6gBS@!_)~p2L&Jb+}J%#l=UHuf+0}_qJDqu1zMfIUn1et)hQ=H_{>@ z=g>Xm7of9a$DG!K+Sm}-FEo<{p*I@4U#mXu?p?|ET?Y;v;Iy&%>JhmP91c`Kv@oEY zp8*a?|)FnwSk1 z9?;%tWXgk=_I;$mavh>qyy}4u3ruG3H0@>GGIO$UXbHWy&*^$Cj^>YpSHh%TUs$mL ziL{~`Z+Gi;&jV)9zS|00xPG|1^@g0?mE6^sAB5`%C8DH5Hf~$8O|_~eD$~u27}k1U zZyDYo7o5Xr$}*t?Xo8Y&2~taF?bTy!mWqpu!g3Z>5aSKGJ~y884f+(081l=|Py(vk ze*7AjT`mtB6!2ZFvDxTV5CDg4nuaw~YXd{vDkEHgk$BRs4WqjM72+(~(v7+9ggmM&UvSd3=v1H6WayQg==*ieud-aa`rO&c z*sf$8<{Bsr zYBvtpm*!?4GT4R*#)6q0O4z-tYB(7r=e|_C!Fd%~5bgps%ZXD1Yp~6+@sR|8LeGI; zL0u#w)h}3}j_wXJEt(o$XK?{}7Wwbe&h@p3L<~bqKoA4jVyZmRUV8;xO5nOYC>^l- zn;c7v3xBbU3^$okKubd7Jcmr71Yl^O8a34%^ycbo|CzqIxxtewWZF*6WjD0jxY<%h zd4!{!mf21lA9#5`)r|G3-KWRgCnhuB@`B(WTdY<;SU1bdB<^G)`{Z7Y=gkZ}={Gl} z(d*&YR)K4y{4_%e;l_Ugt|#kJL--@1vI2P?Ed1o$e((TvYsPKxaI8$-V>_`a3s2uq z7lYX8e7Z}k=qle_9MALhXixLNl(6}BsG@+Zg>u)eD4)#C^)&D_^&K)g2B6(^ zk)-4C(5@hSq#-he6u6D_SF}{g3cyTuyY$-#cgw$A)$&Ch%!3~`-QBs=r`=84uqC(7 zYKR;9YAK$CGXi^MosZ0P=s9S#q-10|ZioxTajq58Wu!HRR=FT2X-(>}d7pLaITkCa znU$U)@u7!3}}d} z-2+HwK1u@`&Po;@xTz>f0S6(~I-l$9-qlX^P?SVrCI|`j31FlqJ4muL<^Yw}6b2g$ zCSb2AtMBnpJag{m2M~&1?pU-WB+eB#5zS&Qa)7VPy{e(?{jix?8yDr?*-{Y7{{x%D zuy*zCe&s1rZ94Iddc2nvj;mZaQTfiOZmH7FA63!;Jik~2CmI*j;XyAvPWF{d*`1Ra z3d)>>?MD0ipLqTTr399%#q01I~Uwk7SNt4N#y!R~+3@78@3}`lyi7 z{FKJTE=j8+2Ls01T3Z$UY>@qXH=NliybMj*oQZSFWBO+1<}3_(rkhAu?zgdtDNh68 z)(^e9Wy@b;ZZ=-7XuoO$jfsM53WC4XJDs0!&Fr~>!!i%}D{OD+QbU-Vl0DYHkOB-Q zn4)7g2X)!9)J2g)@W|OsnvE^SuHS}QvVD$B@7o?cwbr>rQM4;P^w-?ct))U0d%2oV z%52X45#ZmJK5y0bdmbNL62i3p+O(qk2t{igt^FzUSRSo4CmMesW zq`bz{{M0WmTLx}xL7ox)Yk?{17=z7wh9B@hZD<*-dO*O z(I!7(Qi;IPYWKdIXnDw?RRwxfIhbZ(2TcV(PODADV^D;fZ^SZpQDEQuw|~98x36^l z_HQ3=>DDjSQrUd6%riFXpjY8%`*GE4OP7d^w{G`d-ZmUy+vjbuFRsD;X@wi}oH zQld@OUKR$N_vgCR-nU}!@}a3G=8t6i=C6oem&JB>MUsF;u74IQ(m?gMVw~k&i__Nl z-ZiwFBJZEYf3nu3!*=wU$$sRD<@&-15FmJv}Qjue|l&> z`36xt02e8_Qo$gb{C&ydVrRnsE?fLsy7oajxQH~<26SsQtaRsfKE}yJ$5&kGSi3*S z<|5!Sfo-mi14?#MPA|b1jRE0=e4ikCq7D+}q%d`&A-r(E;bi9hwSNX(87SVG5C~AE zvj9>g>>B-1esr_>Q$}_XS`B=)JW7qNudSMgI1J9ttC6YEeSWC^=}b_gM)$T(%!0>5 zFL{_b$(-F>87;MJ6@0MR}QDe z*3H@IJj?c=Q(o^ozmyG1nY9&tOJ~@b)|`<2;UaZu&#O@WsLSD5zL%!ekFWJT(@<=@ zSG#7)r~I*bN0;n8^XlXsS=0VT!_$-U^x-LZ6$=B}KIXk&H^4V0v*_cd_j;MmL+Ty1 z&(U_Ek4h3fvYMiX8xY`lW4)J3hftOx<**n%4zs5HCAX(BuzUzo&w)em3v!La|8$;X z`WzAHU(t}sC5#o@G@vwn&ns4PmH}WtIXMw-Ul0gH<7#YPg9V0U1}~=OH0Y|-o`T^B zxWdx_2rN0=3Q!F`u{kgXgUSRWj8?Fi6qW01_}A>tZ6`~&W&dGM z>-Y|G(>eE7-$G>~S8VTdwGO+@HLLgP{>N?wBeK@%qisEaj?jLm`7FNS{2?0(1Oxjz zU{`mxWFqbO=ihqQhJ+{Hy*qqzsqd1rM@9OW5AMO@?^$c6{|b*kG3WW!NKRYDKb`mE zY+s)fV43dE%C)5{Gn8F@)~uadWli^coqAW&{;u;~`ysKn>$|Ei&b?(<3h3(5%hJ{% zkE|JNXVhv(Y}1SxRoBdyG@Fc(%eOx@{#}GyQ$4Bapk7{xm`={)f#ts@{$!TcRkexh z*Ov%8-T&Au{p@L|njWX=O7mTZ^nb}Rj;`BRcC1fI$)2XC8pU!TEyua@O6yqFk6i~h z)jbveX2$Dr{`&K&UoNa08SN1Nc7OSgGec3*ZLjVuP*0%I*UldCA z%}BEK?6K@9NeW0BG8_D>Ipop9Fa0g2lj6@v#R$!h`onj#`SY%A-amBBtxw0L_{Kb% zwZ7u$bpD9PQvOM&*RYhcK0ms{S-^ET;$le$=h+E?z6*ao=C{2{ZWvMFOP6jq!+JsX zHEYy@YA+?0reA)mYdAD#C1_1P6+N7p8oOq$K!N2Q&ipd0`o=2H))zJrPm;F%n6c8k zMBCf0RVwUDv%{yYPY&!#tDTlUg|#H=t$Iu>^=Db}L7%2DQ8w{2e1*R&+h6YUfG z%w7aBU!2eMt7SO$ODQe0$@MNenf>@@g3FZ)Gpk)aGVA#`CzkyXXsR9XalmyRz+4Ap4dA!(1tm(p69la$7)P++^OH{{t8;h=26sB>$*e9}}akQ|t zxU8_OTI9JQ9Kc_AK^wOf1*E90KV)xOE|$sD1t#^x_HSd2?US(i_}5G~K4 zue54`{uYtvZ{`@gY-{-C_pE*^hbDGnj`xn?{$AV22GyW10t!#71!hMi=cPMpZt1nn z``5FH*|sJ#{`Hnsax-+cH;L^RSjRRm;%1NETn)kWop<<~itZRZ zs;b_hZf>2qEID{M{$!_G?pcMSmns?pRl5~;`o)BC2L8|Em8>v6XOnjQ4*#HytqsgQ z^Gl8O@*N6N7XnyjyNez05H9;2yt1!Qy()ERPus0e8}8Laelx~8Tw77mI_p7Y%g4J1 z7HDraH4YTWxpI7(U~$hM0fDmF!bz2uRZlCLHOyaKiR(z#e_ieQ;o_oKNu6P`dK0;; zQ|G2!>fXB|=~#Ne+Sc5V%ikq8J1=ityVm7kKE2KNrxski$8*X-pgLCe_g41!sU;gy z!b2h+|J*0E+5W&9UO`4K?AtOQwmm0)ugTWjU}5*WOY6eUv&;PA(sCozSIOO&y7~3+ zC|63L%Zf_PDj``@!;Ls?A%4M}frd~246FTn>yGpCz6y5C{~f2A7q(($mlCL|AL=hy z$7I(#40T){RQQ&x$2l)=y5PO=xSZ=h<|)TTwf~5g#1ZV2JJEhHRPt&VXK@;9R3Ur+ zPWc;m*A#s^Y9O&IHaqdHU|{TUztc+{#}plQdEn$yJO@DhYDtpzjCiB7Tw}+L*F9cM+P5yE$yk#PA=Dnffkp&JgDX`%^wkdnq$y;-W%W3|nHC$GupYBO&4b?rcd#kN_u_H}BMZ>_6eNZXvW89I= zuZx|HSi=P`rgBt#npGa3Dq^vGSyGdj!Fq}&Uo!m$4;C%%l-gr<#N5-_=lC@~JiW?azd=H|8IWTy)x$3{EENo3Sfs zt__=7aLiXCaJ$iouP-vL=smF3kH7Mqz1yeR2kq*Z!e*z1ITC`|L&)oYJc`j$1aoiVT$7wD_-8TnE;; zuh1IWwq+zHKg0ffzIX0OQIO2eBi6p_4?jv=)PBKyRM<*wk42WyfdhP&>zF-8okNlF zL0W%Gm}XJ$9+N5u$SbQpSASf}+n|0@c< z9{zE9uKrkV%08c93F|QLmLD9yjakHAo>^VjXo%O4)8@k4|87QQb_VPAi>Ejk_R>Wu zb;6b}<-_i4{7zf?jGblUOZ&q*+ioBHpcED{qjUBAX&GOPb-9!CXGoX*=wHQ*KY0+q zIYqinZIhqRK{=Kc`)?c%pHf}PDy^XFYN>Np@JRTSLwX$2Q^y6&d_BGOf==B1F2}!; zyTj)#k1&fd_xJU-Ww@eNK2TwGjJUY(WW()y38U=o)TzWNsuUWK`5ly(1$q zKQYC8?g>%DwuHIA4weTmst|hj=$V4%A|+1YH^JtwgcPQ>z1aBWdOiEutD6=*NiZ7y z>;C9MM+Hm%(3-JUx3$;uxr#Gn_YD4$RsFowy!q-jCxKswwh!SwbY}zi zp=bOSj_f%?%h=^>#_Ap?cgKlsH+Z`|;aKLQ>3T($f!e1EGKNjwRq|h~`gMA^eA6$( zyRDziB5zM^t2T9GS;%-6=&pSILc!sy^M^B`fAmu31mCni;5Ey)&i2;apX}^D&v^a* z>+YC}yd_ydupqizJds6pFloStYh~_Z6`uFo*03)3J6CNa+7qtw((_Qonpu3z-lZm1)5P{Hv~;Gyt!xgxk2vuvJADwy=Qunq+1 zH561I-hM{I@$vT;#}=6EsLen7rQYvVjo5ahby_RKrg_YeDSmY_TJT)p)0Frlj>|uxE#~^#I?4SYv{|Jq&4PBGyi>c@vU6B z@Z5b3{TW^h4y=m#ef1Bg*O!xJ(Q1ePRD`D8T`iZqYkjG}X0|68PF||xItfpDr z<4(a*zoshpJ#NF>BtqpF-JeCS6*UGUdDXrVA44ayP2D*SlAqgfc!&DLSU*mP6tt7^ z)m~qq%a=0R^0x1IhnnTZFpF>-l})#Pxz2f%GS2=~tJ!RzTN#P(p3uXs6P(t^*? zvo~@1|6%GY!=h}rwuf$zMkG{Jx|Bw`6zNjBMClIcZWu&Lq(i!M=mr%*dT0n`~CJGj)R7|uUKoH^*hRZ(T{9O^Qu#^^S~ZAQyf!D^OET*bt^mJ{7W>a#BS*M zn?U^m3q+c!_y=N$$!0vhUQcXqT`U=+CgJX6r{SOj&N+@eHd@c0VM!?(OgYy5o%=8z ztHWFx(?wtG^ss3%z8sQns$o|>73P75-O7#R3A7|5K8&26_P9_AESYfG-B?4DZs5nayes86T9ulH6^gZ34^)-VGU(X#l?D{h(HhSA z$f`-JqML_oPS_bVQbvkHqrD^972dB^ofowuwI{?IXeJxmqkfnkjUJp>(^qDB>g$GwYJPR$L} zjiw7%bS@!={uJvkxqJS0URH@;Q*w~Bk{l6n_gtL6GrVlkC)Z5*xP|`sP_dFpbE~~D zv;)e1KPnZvAv=J~4})(^KLO#~d&SP`Qn75|pKVEgVO;JKXII>XNdBV1WmCMqq!lX= zjp?{zyhHkvJ0d#33F))nCr@3;{@$W%Zz$AlNUF-)MXWu{pv#G_%;Eas%}}c|j?!(@ zy|CANk`rT15i@_uNblj5M=Lo#B3Afs6BL+Dgt~{*`$=~9`aj}SQZs4>%7=*a%3Zbm zURr8#GdS9NwnFkd;KA^yfkx3s$@?r2NN2O~efB;*HLj?_ezXip1tIu0{+AIP(tWW% z98oN|!4z^`(^=F^(|`5e{7Ty1do@xnBgT-P8x&(Qdvsm9o@8l1=tRfS0F!+pX7sE? zi$qC2=xll=cA3TLH}9c3Pky6Z=O*BmA!f`zH)c4_Dr5x|a{Zxk8KX?m+%=P|AJvT> zjuk+P5YaahoF9E~>MT0!aa*SAqoKyNS8<5=u%Q~ECqmsHztC$@Cwt~0yYIrEUoZDF zIYrX{j9HC)A*|@D%jmmH35)6ul@F}jCDQ7sQ)hvkFdw%bo$mXKoFL!g`28LtsHT?8-si9OZV-{|Y0!b6q-G@Q-a^%!2?K}ouUH2Qoc7XC9vaD9T(Y~mCLg|NskM~Vk|D54U%3@aVhKz+^n`yXIm7_PQvE0v|Y|n%v`zU#qmhT^{dd)Be&RF;hWI;Sw7(w5z ziY8WoGi_vL^5jImv(uI=PtnqiC5nt$GyaA4|1F7S{J-qUj&jxMO}OLh<`Bq}7Wx?C zZj+H%j4sz@CVu5KXjh#+NFJ_R;9|idbgx)q3C(30X$X8%q48c!2un}k#fnztFZBg3 zP-OMmdg7_S+IYAkL1K*`9JGwv`4#x{PTr);Ft&A-cCqF_ne=rq8XDByL>tPm=Pef3 zOg0mD)b%yqI#r01(QUzSA(G!XTCx|;4!wqaj~Vo%j<0YQgbjS`%p2Qd4pR2Cs5772 zYF3VqIEg+H#+DL1p^lroS|Dz`F!R`Mu1{-|0Ied;$Jl_t8Qe@^Ed%?-titxVd%u#v ziw3s{V#o>xpQ+mUg{7}o-jvmtTajn3`C0PLjsLhEIiVQ^+?zJ36vtRd9T@f2OZo{z z0=AQ|6~i&2@bNHRT6X#Et2-}{0>5Ud(gVBDO!gRZw#q0@g&m>POm!o7vS6#BE9L^C z@gBB|4lar=bY;U+E&cH^dx3Pn)UH-vr8CQoR)^xqb)W04^;^~uoTaSY3sS% z@Zajqu{Md7`(KYCR!jEsG#^=h)X4U`pNEP|@gD;|K50K&Y{2tXRD`*8zjlXZM#pq`(r`VGT816zkB zOn!RaYGnyoStdCM)wDEfx(AvVKAgY37qNlzG?%anUAch+4EM4^Vmb(C2Ll-+4s@%y zsa#Ey+tS)!*_21}q0kri3>`6s_caIgBxpQac5>mlg)~necW23$m8>1hU5aP4QPXX? z^Ge=l^tbvxzQ5@j!kK~#JyeH4R+nvgGCsTx**eWrajMS5|A1b=InX!oz$5NcTZ8qm z4x7%>a*0NoS9+Lre8z)X<|uT_9&ek4%&!dKMhmj1Et)nu$QsykN4{K2TlZ}$+4Vd! zXe0fEG$Ic&*R(=MljvHu;OV{$MxVX9NcT9b=F094?m8Y0u_|o8{L1j=WjdR$VZq+@ zran$|J>!f+V{oa1d#PbXHCFo2g0kDoDY3|Ief3vdLVwAPH;a7(8?N}fDui)|pL9Cc z$C?OEhlKaCs9mpR32O;@m$7-2-Py|(bqD#)BX@53NZ_I}uJ238=(k&y9RuE#%eJAX z&bpV_|FZUV(-tG$JT_9D>hN1jCRc$APd)5}ylp&s`GXUP zOTq1OGQE^v;+KRJa#lC@fRjH0<67#@6FMh+e&~OTk4z?99r-*`?pDy;2ms1q8*B)N z(dRywKcc)bJsFRWpC1*s2?UZ|YlBz&n^#vpheC?C&Mq$6jetVm4U#kt0T)a3s}R5o z2mpCkN>LFXOy5UPSqn?I*Ni14B|dInZq!ir-9Vz)`a%4bnzzm!2lrOAYK`CW#*`V- zUSTkj=odAI6@$Q%Q4!$?AMSBprn9#9);3UJ9=LDYMOYZuxP3g5P^1vSa+zPv^^=eBqsGf7p54+*o~$5M#O>B{f#@hYI1Y9*W8K;Wk?Te>Kz zQG|ZfHCH;xxgod$BUDD)uMO*bcqU40(KT#Gbw zDur;R@O7HvUuVMN@FIac<7wdxC}?PwMC zkalyH>DYlk(gtZfG?`5^d9)rBGCcxeT9My#vGrI}WXO=4n>Ug(h5Wk&F+?SI@&-ST=V?oo*K%b&v z&yprF0E|u zuKgvBsN0L?ZkGg_tS@P62~8V6Q-gsK1&UAeS;*<(Y-3|w*7M7q+HXTceI84`;hUD{ zA@F3x^{{4s@IcB*6uPRWr4>2Us_A4Q=ngfDd;SQ$t$j4&N4^>*~WZ(uOT(h zL}O51LC;~Lq2t>cL%8dB@B=2asb_u5*0f#PlSlj`^RiUMY^0xp3Xlrd5*-+?wc`pL z8oNj8#Xi@4s6q}E`JVai23tqyCFh+l+hPoX;Y3kXeo(D8K6idwC0$PO=#$Rq&&d+8 zGOIXsS#l3p+RVGZQ`89*_H6>MTdv;iFP{9^|B^Hj2I$gS9t$za#L>~n|y-l(iw>s7~@nKemcK! zIk#BW`?;Nks0dMl@aUU1AtVB3@C`~Dxa06J*brjgPen2^p%0l~!!b`f5PYYj;`jGM zoV4$CHSL006m}7I3LRJ_A`y2 z!F3=GgCvk8+#F@yngI3$4Ub8{+y)anI|10Rf*=QC^Q3)Ga-0wNf$G`a+E4-*=YSow zxOvXgK3Yo`Y)UpH5rbQpNFm8;A0o?XT1vQ@ZKz2!XHl6!dutUvYz&D9JY6VoC#k8g7}=D>TgC z&?;ilv9Waei7Qi0h+PYpIQUZ5le994B*w_J2n~3IVU|qIV?f5=MJ|gPc&h0V)x)$$ zC>}0e*S>O!s72m`NHe76>ym3l)n^+xHQiU@ubiobF+hZBv62n^KBjF=s7xjmu?w7u z_*R!dIwz+WWO9xQw!A7%z9d|j8PVf`u2E_{T!<51(o+;t*I&FgAvd{u+&>T|F7myJ zYG?1Yl$_1P8%>z;HaOR@Y|-h!a`>^i7{j+`iPp(hSDZ)uMJqdR!AWRVmQ&SCzA1sa z4-X=pUINn(K>Jx1f5pgwgShoIrzS)vylRIkhhhf(SlPFkKa_4{iWz06$94bR+!*!$nZ z?=xYK>W3vfgq8~HAK;Y_pu!%1{7_W-!iNWlcW!MxjF`L{N&m%NZkAE2|C)*MN3T4@ zgj*-)CZ5Dp&!bBFg5P$U2T;;i2Dv*DId9j~96EVoa+ydWz&0!fFarLu)&cCxFKKCq z%fTe5$UKMk^8(dN6tVv8`}cVj6~q=6R}M8_zR-dTO1O!Iocx@39rzt(ok(#jK6A7R-G%gBbHwr+(T^zaD7B^0L@K3 zC5mvAnN|cbKx092M@nQMhqlx^x*%QN7JSebrWVxq}`cioAyw zS^ClqV^_qoOzjmf8>S_^zq|pdhx%AACgegL4LTmIWb3d@~837jrUm??G(<|qb&%&I`ma|Y0#P;R>Va{ zYP;m|!>#pW!}?ts@hprFca3|l5~$aUBHJDl)W>ZphUA644ffo%-$1Y`u6a31PU3k~ z<5wwv%DJ!zOr4(Fq5}ijW(TUzcx}CIT80TV&y1pi0yf61aCcAhw=dZqzPa>WvoO?F z=oNudmv(MN$i9}dTk2&&o;C;IkhXbcWVV+Oy|fe)zu07g+`+k9aW@2y6Lh60AV2T` zC#<%{@@l{txuiuG-k+jV_v6$#D%bIH_n$Psn#~FB_A_+AV{s-s&i|PqN-ZobjAAaW zgL(>-N~`6Ki_*4bH81mvi}5HaDQoq=;XR;vt~dl1W_sag`JUVrzubAIlqN6~ z{r-3$Shj@hDu~72a=sT~Yyj=7ZcE4JMNT~U{2M7MFl2XI-LJ|?-f5c-}d-H`Wi z#9?Tlx%2Z9;iAw7qIb1qabdIaa^ZPam#o$OsiQbfC1x$Dcs#GCP!s(0nH7r4C*&DE zY8|7C6$^$|$dTWc^+bzfPVp-U(+=Xv)7tbFOvkHb`{nFm+%MuZ%F*XtH_6CwQofm{ zDHi=X+#@MJN%#m8WnFTQ5{A_Cn&F}l&tJL_Kr^tS5h;ycw{>fyv9sPuDlRo*UC#Wt zPlHA3&zl>^%$s|JYS_z+9k4F9Rj;c1PbTPBe+c73K_QMTH>UH5 zx6P&_c-*jcLfT>>2S>l%= zNep8q8oP;o2d(#~mcaF~?@4~a8HK*R24=l@`N0}6N7GLjN zFjd`53LqLB3fy;9xDh2Hb}6Y4d$?}EYG(#F#$;dZuZQkyMu)uM%7Ms5^1o|ZinD?9 zOp=ZwBDjYoes6DpI5-r&9f;a^{jScY_MocOqI8igD*qtXImqI+Ridgs?FvEH0MIA9 z@CtN;(%tKukD3%0S$Hoa;nioXj0tQ~W3k`vS^4vx8yO$Y*HTalHXt3qK-~bqV&-t( z<(Jjodmb8;4hPuosZlge(8WUzhDG<6GdvdKIlT|Ma1H_O{YNrh_oD$P77mVH_R7|v znwqEcuH(wNE~xQpaFgzZ{Q2o*Gpeu z3YZE(#s~8ABl^FTrd1pT(Jv5!=+O9*3wp2;026)mTapQ^JEtEoA5KZ%*kf zN~2CE2!$oLH>@qh!38I)SPGD#(i0YZ=sJDj4pKC{xaB^QjHiwujlX8VP$6&3xyf)Cu4$Me}Yl(o83&nb4)Um#Xp-*MHA_ zQsguG-Xb@$8-I9w8PbRl^jfw4Btpb0xS~a$Lqq%XXP3#6E&*R!pxEqaTW;x75ekE# zD+5n9DmAhVQ82fH6?eF_1of!M9p}C$#(w4yrL;q4rj^mKo7xW_!aQq^GQx#^guP?s z?BabE=+XY>jbmZ7A%l*OFA-l6^yC+{s?Ew_hU1;Ur?cpI>YZP$t6GKwYTxvAIqMqF ze=yu?*R6pk*Va^Rm)pFlbj)BJZuF2~R&F67dro=av&NTCLwh`(S#{%0xj*u}#4u|j z?IJE}V9>Q_bu)1J+nOO6GYuH<2E>b)p{Yp|X@-)36Y4_%uo7qHJVcQcZg+eAt_Qag)n z<+ZmnfGfhsRzfQ2!ixmaO|FMkI*I?KfHIV}0Y+vk6}mLL{=>R%{cFkYT71!?ogODT zpkhMh#`)ZwZ@|vi(+D-mWgwVMlp%gLGm-a21+KSbgbdihCXe7@ENr zx`i8ThW)cYFvJczF|t_{4a9*g10rc}Ur}+Llx&5ZOTmRlPI-7l_dSd%ZVv;MXUB3Z z?7_sdak8jQ6np2eIFAU8Nr|J!2RllQxm3tI zBOT;Q_n;8S&z4_7sqYxe1nlbjl)`X&J|^td3J`tLi<)YEg&&l+FU!p8`WAPewz%l% z+wX6ag3+W_8m4^-h>;OaN+mCSz=P_p+oGBNB0m!U_uFhTXC~Z1E69qGKtwp1`nN|0 z+ItIaxhLZl<16M?#}(|Ylm%~z2UHe2PWr`l&FGbmZk}TrWv$l5DmflYzbRzF?`J4( z?exr{vsK2}e1V*PJ{T!IDrLKwx$e{gaMTInf+ga^0Wt0TO$N$2u`k@>#B>RHP2~ZZ zmyao_FlCmA&1{(s)$kj9TT2}Kw5QM-Xsxs=UH`Ma{?_^BtsG(ywwhZFSUGMeb|p;) zXu#T=O1#8Z$OPiYyR^58-$+_CUYYQC3XLoICDo zyKL$w{TYBIFVmPnl18cKQ^~;I>Z&d)k^H4wmP_#Ihax7WrN?_jMI`k-;upVagXR(d zLK7?+cx^AJHix#Ik6FD?rzqNET^3-# z_+j0dq!Vm??`vzhV`ipaE^N_o;Qr%tX!n}e=ddd4th3LVW3Qd=bq#$7EDCp3u4xkZ z5&OJr6v5f-)=Suhi7r)^-vuj6R5vga)oT%LSNL2fMfWd7tOC6#Uw~mUVwE0RG2mZ( zhMPjH(9{vXpTG?R?foIptmG9H;i3+5U;tqlB~3Sg3Ts&wBG%Q_1zCs>o~UM)(H4W1 z9tBv;Ej!Wol{6?qfvp9|7JV6`L5ZsGP^@F%`^vg?@64K3A5ZrSRoHt;OHm)piZZ@ zw}`dsFBdp3ktW|>1T^ImPX*!k3BBgT^k9PUIGw;J6A&MBSuj<$!uuQj(3Pm_%ESVe zQ)POQbOVvncPSYN^Wrd<0s=EeT3gW#?3g&1@lAI!OJEtEDQ$jv)hRbcuRKD16fmn( z?cx0|1CW^Jj4$a2e4A8)nBddsM@HiBrbdMC#{39qFd`Hi>w^zD9|x`f!Hn+3vf7DX zH*EQI8VR~}Wzta&C$X)zisO0iri*tI^OscnPuku@LLgL<{9h81dW*;4P5YmY`X_r8 z>7fg)i~5>;&2_21@C)~Ae|MTEsY)mrASEQF#dcdYfV50)V# z@KQKAxGJ@lIKU*l^QH=ir4sA%(#i1&K@sHCPh7q2SwOF4cCX+@mLsy-@7jxkg2Du) z`&EVCwmAa*Nz@?LhB+`nR=wCDRqA^92tTre;jyxC$F<{3umabs7BgDpFiISF; z4MD`>MZ~yrUx$x~06UYwT$a8+5#K5^k2gPqie=~ zd?sDX`PvtJF)9pHyYX47QZ8ik!)>a#uLzKWh69-dN^|gK!vv`DQTx?maIa~D0r+xa zdn~x6*m{DYm@Cq7OL(l|`zW7sG$Q2o>tM((aE{QL8>{VbO|2?tKb(FIm5ItV;i2J^ zkhos6$6Ou_fuMIi4ua)`Cc{VpMDlRa8`rf)F>vnjJyXt7sd0E(pI%hI)CFiB0+I=H zo1?TQ^%KTYLPLoRH6=s;3OLo~L-T$j*TV8NU#C3{H;)|m8njB=Or$B_x^sldt21d9 zIyq#CSuCSYTfDvHM#I;~kT=_E-DsPE^kJQT*zS)OT(dwXU(Z8aL!kaHGFly9f4sA0 zA422u)D`Z4xznAP=RMet77FNulpR{tlSs1+En~SYTIaFz3)f|ODD<;?1{ER{oA78+ z&U(j`+HpwXFMnfFgFl;NZkAEUU#Qf1){+N*kSU)&oWFhu8Z5PRVA4g&%x z-A_I6rjZI3(H6*il`fTjOeWNF*#&OJ5fAAGX6T`j>e*(k9+zs@qSJ95gG~%v%d^_7 zQT%#E5m>pF))rkEiEyD~g`W>OVRWVU0twcdPhG4v3wuZde&!5VPdFv^t+~f%y>EjP z(wjgQ1POkI82S7x{t+ixVvN(-0G;zw^t{Jy^coGYVrX9f!l^xbAo)4MG+QcG?bdc3 zN)fU2I59n?;!cQ4VLP*I*TJ?}6nbXi{!9KMh7*{{Nl~r$Yy1JHM{EL{}SgH}mf{CsTG&=NZiL?Bo?Ku#6gVw$q6ip~w z@jrk5SlS!Ec@qKrHS!#m+CS`64~&4q3zdnDx=esq)zvH6?tn#V z*3cRP7$u~KSvaKUNKKyEZCEvRshaQ~ha&$I8z=8^2b|H?cO5t4v@~iBUe@UZU85Oa z65E5G^6(f9wCI(2rpQufd6qtsp2Crip=dTFG>L9m-voAwyw_(gJmI9(-3mAA7!{KAzm`YS zHf3X@Af+WX^q<(v?+G&5v=tpt01R8+!Kpz|@B3#Gkrn33tvu4|HY2A4kBO)F!f;%X zXX`Lvh{#;HoU0!auy%(V1zXySiqP+;pqL!)zicAjTos)_=}bKBl5A^2HXR9{ir~q3 zsZ+{9c(@k}IQSB$APwHtg2;Pc)#%9sE&ZG^TQ1AY49Q6nC(Vt&@#0;s=V9@x?j4gbZ{QE9NKg9{YMc zh=E)1crP(2395b0t*Idc1_`o2mgcw`E7;)fafL4Tf2ac!W(*><>h|XQ_jE~sLhp3= zO9#k|>DwH@J8mcesLpY7z3<;L?+!{BOMCnNmA+^>!3B^f76;rIpa1>FR8)FooN)%! za=E}_1*IQPH3*UrQMSKZ5~;|Cluj519= zGjh3;{DhHOkI!O(fFl4k5FuxbMj|1L>QoIT%@t?rQ8^Z*A^jD*>cU#;K=3$POH?Nt z5B|EmB0J6R+`vnmW zy*gaOmRTZJ5H9%P_WG0;1eZJqaW7~4{~-Awp~T3*d8P+g z?T)s2djL3YoZHR@HE%Kn_U$whe$N^F&P;*#9Lj*f(#px9yQC0hy$B$^$=Akm=H^Vm zrF;~)`G8b2zI)nhipT%}&FP8h$)@V=%P)RXw}iSqS|OeJjLxK)q8>3$Jv@^d3)kf<8uzwHgIh z@X#UXDKs+Pnh4v+i5@rjLtSbWBjtG~A$d%U-C2m5hmi9-0@$-FkrMOVcNFy>hF}rOZdm-& zTq*=t06vgbXY_Ws;Thjq*qo;>o=;v)J0$A-Z>5f^;CvW^oIL;}>RtTD4JA0K4G5S0Gj3icg z6Y@GT*Q#Z(ItOds{kpnj=XuK5Fn`$%+^cx9?G)LqCm?~nx!a&`q?)WIa^|i4L~W74hT9B|Mt>nbI_v=dZg}UuQ8kr2)=4u7IZ( zQ36DS*0g;Q+6p0h0eaOXdi<UTIw>?HjDcHpwQ^G@V|Z3k>2&J7dC=_UOaGn#f`(B$ z+M0O#5D=rj_+Q&q)yL-ITJ59~vrV^<*nw(&XAZ5v4X4YfEC^}#__O;DBv&6?d<62u za3=K+-qS91je*fRZCZUq%aQF)JoY%Y(auxGwt&v~;h>Yg+NSYMkXu zmOH;iyVUm|;XY_cK-T!>Z6At0M%y(H=vH)`Qw_AS#HOTuG$&tnsGy$Re5Nv8fa7|{ zFl=Y9)?|;y+F@MoT3&%A40|nXH940>Z9${cP7KeW4E?n>H~t5@fxtOuX-@?bI(M^M zul#^HH!7jm{r(l6_pQIvK}RbPNR63Y081yiE&s5Q>ibU6@6wbaTRo6sR^y)a>8#z@ z9@4TJlhE?n#ulw*ma3ZUF8cZ}xHo@OAV6p*#jxgo6M(veEz^6?HuCJQG*a;h##yZ# zg)SzP=Akns5GZO?wEgAJNe0H3^G$27qMKjRlMUYa5P_~_Q7<$#&tFleiQn$1r7rsW zy@>m*La*RHDL#kBGLo>6bu1r;z&l~TIZC}o@@%Sat1kg@t6}q6eU)FUdU|j7&r}Fo zw`hE-qZ`Wcmrm535cu56obJKxIk+rJafp&Nm976PL4*AQ(37b^+QK3iaICVhu=sN6 z1Lp5%0O0NI0B%aI^5+k4R?plLVIa=H7Ezcx#5A#E;Bmi;OqfiU1 z7D4J{Rl%rJ>Jnore^tlEg*{eDX$r_eb;-~ViOw0Mf@$XQQ;Fn5~zhsI9f z*>94Uzy_+*i4M5ZFU%Sqa*yr%t>2b+k^O%@I0FD$&*+^T;hi7o(Y6tmn#JZH%O|+< z@1Y2_dvseqF@N63!qf%F4AHTt4HIy%ix z%(1Hgnzc1AYW+FWLC>!EDB>flr67A4aaH6u&eD_VP#&h$XQQh4;{y0Gc07}6jSl`F-ZQe;(3N}(ro8J!Ic2&TsRM)f#IIxxzY&_7xH4JH+lrx0p_~u5R0eel9wZB=2;v$e)|@C zbohYTsXY}~J6W0UJm!X{yeAn~zcE?cA>3PHR0P>hJKw#!e>EyyW-SwKaGZb-^&{&p z+B>0jkuGC<-_ZRxi+{zrhdiP4_(&Nv;-odtW3?&ipuDkbleWlug2V@L%46CDBl<#E839AwGGLnTXZ7pjs15})W;hd2^b9}2ezE$%cd?3(Ssgoi~e3w`RyVtM;)!}i$pT0Rh^V?SJ!$xcYZ(n6%@cO_4D zvXPn;Ew=CDBrOnEEp#;5F?aFL61dlXveYcmj~Z`ZlqRW)Mx+mU_P4Ov?H2SMN{2Ci z#0MkV_-j9EHMOov#u-Xi?xhTe5`~hs%jr7{VHGl^)q zE_f}xldXEA8ZWB;D6WSRrbayd@OLp9z3)@ra8TWRHc~^YXzsGQH_A)IsK7_CM5$QP z#m^^N4kx>rc4!C$9e$pBp9pl>T>Gn|4plN-2Q7Wos2e+AtzNs?CWajMHTjq;0aWOJ zWMtU??CRksLK^b>c{GO2Zu2n$NqEq;?f4RO;Ms8Kba32-OY?ab-rbF2;T*8 z%k&b)?YH<`3P2SyHD)~{GRW)AkzKi>g|_&ewbh~Q+(Jd&XY3*tn{mn^5hYav)jQ_u zK&`2sEBWVBModig>a(XD<1P2dhu>1v~WSd zsV$qV=qRgH!&a0nvCoD5{Eo$6mbjP~8WUKlu3sKxvPWDvKy;n1^v9mWfJ6i? z_GFiby_+umfXm)0{`xKK@aL$f%@{4FF-cw2 z8=r)6=Zc0=hoH*p4lse`l%H%>)##}1a9GjmI|+HC4Glh*%pCVhUEaj0o&l%=ZbK8t z=|=Y>TqxmgipWnpFZyGZOgx5IDg}MXH4)UVKV=SQcG3e$1z&2x*+!nS$MjAIFno;+5CQ}yTWsb?L;+rwp3K}nFfya$|i7zKPi1-!L*}8jk_UQUh}!Db`}%{Q zr~m2zCV-!wIMLGC`2TVy8d^TdAaBf@z1VZnp!NR=hNBbgb`M=iscygGSoT+YcH}*U)C!ZxDz51Y&r8IG2!ydXAr) z``DQbDO=Cs<~bRr zkYhe-H<%G=XuoSKCA?I}82XM^Cuwy@WjLojP7~{FHU?dVSYd8S%k}f1UbTHAyA!tl zW9YNtx7Bd$-Lb`e^1Jy!7J+8LmA)Vjg}6H7+KFNTfX$@l0CBw;kaszU5UOeUHOq9L zsBR%+0|&qlT@|y1u;Ght{H&H2XaRdx9mBdF2qCX!M@(CO5j)4iM?rbA)5!O(<+NI- z6!+6Bn+ENl<0A_1^|}32zV=-*nqwKQ|FQcEL+^uA0(tsRH0&&bN($v*hPUJ*rQ2&8 z{vH*+EATr@eZ}Yvuw||Go5vnyYFw^+wM&9Nx*Y(lpxoJXoM7MdWV=!~{In~5b>s41 zG407$M#KGd3ZfE^Foj(*-14Yn!SvAo-7zBn*)a+w7i4T4xVV27_$uh226p7IFe(Tn zfoAu#eXCqyGp;!_sD~KxbAYSUE7Ri+P0ieS-CiOSxa2TmuY}hXOVxc5rA2fC#?P7) zIU5No+MiRGv!2m{UPX|ODwIbijxTG{fg-yn*!0^vPR9_+vZ5MJU8iW{@$PaoHGlb` z-Hmmv@k5)fBAyY}D^P&G5Es&)19XBi>JLtn(@c{aaN+TnGrui`wI4;<$xBQ+Xyw&E zglzGO^$zSfgm=-mJRL32QUPUF$93I;b}H!jzDY<|N+7;lja3lc8Bi_yO*~WSBIc-l z0-vMriciLalBpZsp_9P~oOm;^a7Aok65yI9sXb{zJ%bUutOC0$SlXy8b!Af>AZ`HR z`_h^5J1?=A>${)R;SY1|UH*HzFL#*5CQf-L&C09|$sKdL7x%>VWj#Ld6hR~Ly}gLR z1^7r7&Cy**qwSe2bv>IAuK#?P`BQZxK?;)()=|a?KXTj~X+7^0rvI^)7zw!Aetbs% zJDk;xA0H^jMUoj4EewU|umFBN^76(*i1yOWb*m;f!Q6CbWA2MzKwB6S1c2Bq&BIr$ zB7jLAH#@T*y%G)4*DBWKR%bgqANxK1CCH%@a@pxLPMy|!RuX{L?q7}uU~8&p90NOU zuzdCAm@GQzKh?^Dae8v?8vjjc!c4G-zmqo7#E!q8Lt*#V92}5l1BeO85W{+p%wlvk zI!2cVE@BnW=J4{pTJ3k`qyiS4$L4}h%igi{jqp6@mor!R6)NF`9!NEsGaU3@gO9Ls5d%PWjdPFgRNxbSIwb&l z?tcpa%%Fd2JVpt*k_-F?I+IB~>P1Vqke`$hT3=~sO$pvbU_=bxQv;f@1qlpK-rw1_ z$fN0pE#_UuL3x5mG7I7X7;YglbL>kHH$6F*$Vxs&ahM`B6}>_Nxasd8@y<5t z8$0%_@0}H@6lkeYYztk~`sGbmDA!rMD(x~_m!r>l_ZU*0yq46xq;h}Ko%P3_v#mg| zp+|&2oyoEsyM#1}+6R(zUM^U@0SME2wT;WB6?&F3my66ai#DI_ZS*7H%@T3ltTHcV+wH2qZxCr$toR_JBa|Ioe=NLPA zyOeU{bY=LSnhqQ@fT|l9-(=(~XEXVo@8ktJqrRI7cmFQ;FND>*Eb? zcZnM&E$eEQ-LuddhZxGYfjQ-hQ!UxS=>SuOvUGGZZBx)V`yV~tP z`>lG~8&HA`tQR_!Hn9NtJ5v4e)nUAf&Vd`I(Ek|$!3yqo1mN9luvPrVh19*xPAS)5 zRB9e0R`{obt6h66%b5`!ar-<11+s#7J0fGz!#Hrd;&VHv=h@r))4U>cdw546Y!9bX z7OeTZ8a;>nJ`2#h=nYf7K7af!@ue-3xB|FB9XSXYIoCe{mxrBu zr5Udi;TPpYQ_k_|5a!fJJ-Z?VgBbfg&$ld6erdW2(OTYr{;gb~Smmz>bl%CE$d=R zhA6BMXmx?|;?VLIpo)Mo;sZ$M%?-wNmZT!Cn~+-F1r|4c;1w5iYlYP{`lDuA%aM2~ zT_{EwYs181^8G4)d`W;i&r7*@0fsq8> zrE=hl@5<4C^*Tg86x*|2HDPz!i*U)OYbcI^!T}Ri zAA*)1Xc*=TR<9UYtFs?HEQ8xMtYD&}=kUkNmFWGSOWiU%=IKo(+S)4D8mn+)|GqGydsV=zbK%m!2Bneqzx{utePvXZ z-`eLxH_{!_Agy$Zv~){{Al=>FrGRvImvlFZba#m$jii*bpYuQGop)xf`7kpdWeJPe z_r9;|m-ick3$MR{fakdBY1IxuB~(qmy7`x7dO<`>K=tbW)$sX`A?_AsT~@ZYu@QZT z1@xQ6ar#+JC8k#>U}*nX?86qI7wf={QQy5ZO`?vO5nqe&N-xS*^Yh?*zE9HZ0u?6< z*_*_f9t(6mv@4v(ZSLd)IwvxQwhb_@O<})2>%?ineRVJ@f^c~p@R`!>+;=gB+1rrl zL49X@W?_wxsI3HQj|395bgk1y_8Ya#7Z9K@@W~QSMEI;G$x$-2w9?skMGDm%Cl!7! z5p`tvIj!aE#H`a`ntAPJN8KXBT2uAuhKh+Ktfcr>Z*h*fnnc=7g8-HMGB07KCBg}jWikZwcqu(d2cwtK8fVEeVA|1 z%;XOn*pK0p8A z>9;95ci?0>5$yuke`wF=x0#@4)8}eV3u@}fn)f{)Bm^<<;?B+uHh-YR5bywh&k>~r z{e2Sm*VixUp?Uz|i9ZKGI+cKqh~@7vQo9~NR%R{B%E&0M#_jOo;p1-=#EbL+4SUG; zwn<$MP#yGvhTbVaEwlOUgH&{=q48|O{|Wb*=tFI_ z^Kv?fuG{nvp)A%ig8T$W$Z<4&gD)0NzqD$IRn#vSf5y%khdh)d4$`ApK3rZ$*Rd(n z9Q$ePvXj8lpn)>>w8{}&zyBD2r7}w?dKbXuqNy4&J%s3DBQr4hgA2sm^-^Z25e=%9 zr2t7U+!#Hi%e}XepbGz$RT8q`&(06Mg-dHRVK#qf4R8HlO!)w}-c*xr89hNj=< z(o1UJ5eT^Rej|=eaJ~=}W}pH;6O&u@0wf9#2qPZl^|k;vrwQb$oPyW;)kD*N&BwqiGnx9L-GrM&bDvfx@tV0^35YBPhqS#Y&Xc`%-t`IxMhn7=Qmww}1In_jcu-)xzK)t(Trc zz(@eO)%lDk2}Y0ypo1d3A?C8cf>wBgn{jQTq!AFn|E9`;nWPubdzI?_4#Kln2(YGU+ql{e|RG3!7P&-1U}S zIxaSY@UWC_I0*a(x2lGV$>?#=PUE(N4VXuwij6zi(26mndkB*H)hC`Ut4_-&y0f#LD68mJ^xEyq>In1bKW`Y#uaV_)I5-AV% z)&7EKB&W_SImG~Q@akKm?saqB*MY6BwZjwA&$M2KrHfg(7Its300`tCQ{?b&cmo5+ zMTr^e&L{5FKQk|V1U(V$jwW}bZld!US3!>;)T@3D&Ro<<^Z#a$-0;02hw6NxMKR#T zbPKfc*rcQtExu}MxG!G3cs~j)ekD1`{7f_W{$(D}>7(~UBikX+;u4$H7VMi1f6;88 z-?9>tl93n~#m~RQG_A1-dN`ng?@(>es^c)Zc9#!7)GZ7>n1EKt zZuY(y|0bCl8=Sv3A}Hh+F{Y7ei){uiCGvSv`YyQ=TgVTnalT=iXH$wZgoW~EEvb6c?`Qn5ftw;p;DQ4l zx_>xgf8|P+XwQrVOA@-Lwm(|(U%%TWd&~)*cHpviKfDBeaiPSo%r{3q(wOLrvI0Fj z0lPs=ARN**NI2;Sx-NC{|9^*I)uCq&c$5A5bzl|-?Lz^zTKVp$>yQAT-k}0%_mn4H7J zn6hMy?-hi$Pt1{%-E%QA4Hn@#PHj5Gh;H;fldsrhJo|<3(>PL7HdGwsoXbS=!Apv| zs{W?2UJcFUFGuUjxr$(Bl%J{U&D1?9H?UMte(V(fbdzpN5n z_KX>=4R{#=8{G>KiHoRDqtJmo1$m|d!;sHII8Vc>K<8cMgE{vJB_YDNalMo=eDM9S zB+N$T5Es5yQT9#4y!V?Id~(*(FWu3v_W!i>Cp)nKUK3&9Kcm)*?YrHayxOLrt1wKi zci$Q9N*7k?yPHUeqYb5NdrMoh&Pb(h@n=P-ak>?{7mh__e}N|LyPMN^=;imy{Cf|f zp-&*_Y=ZXot~4_zjVC~htw862HVc7#D#vT*m9$qM1X{0`8=WaCobc#FEf#<~Pf5?5 zgDBzCr=X8C(%Z!ucH&@EuLTti9l%8dmHIN|_-z07Oi4+hVPKGUeFVmE$vSKZEF2uv z4!Fes4wSwD6VutEX?&y!4Fr1q&Fk6^4paKa1O>r@e=;^1H|m@Nsjw+1Kg?f5KKoeh z{PTM+XmrsRc@`-8_)7`;LC03eY3?OKo}#&@CqM8xbxw0vKM$hm?be>9(P)(5fs|tk zQF+a%M3AOuhRsGWDx~O}SP+HLmwf%KA_Z4ks_Za_aez5&vVbYAm7kXmDct2z_pbI* zIZ#_kP)_3EDLsYH<>(90kAxx8vr*Nn!TVLFb>=!)CCd6nQjLa#mPm|ac|qrVsII@I zRadFRcK2hsib1PHO<}+Kv-N=(uNOuQbmMS8Sef1PJjJAIBwH$eJ@A0Ekp1VoRvEXu zewE~wAg45!@!g86l0S^ymV7sA)_N{??MbZQ*R|BQVA4#>%HU#nE7W?$tIoq~iAt2H zIg!%XDmWFQyi~RKT1LVf^89x0cgzkA6_FcUof#3yhO%f0j z`-7wykl+dMnE+@k$G0S=^7AZ!9T0&)w$U!+K3Yi?P|fz*&8>;3=*8KrnvyY4Bl3A& zmJ)pxuwCZ*V!uHF64(Y~dDkPpCgMv|W@ZPIl$DJqUKZr%D{GihR!mGxAoVPKagose zoRd@Yg70%mNpxLC@(#-r^3`)HdHcx_tbtnXYe@^YV)LyUCRUc zJggs&e^Pe8wsDB~`8)}$lZiaU*n*@cK0P4A|3Ruzl%6>{{j>QOl+T1!MW%hmRd$ey zpAJb9~Y z5oF@7R9V@Qsr$!53*ZEi;#_P?UPVg&MDy|KyTtwIvsF zL2Ev%2Gb61H4p~7Xvy#1+G~wHzOnbeET6Mvxx1AtAVQ~@ObZP{7Jf&VLIc$13a@)` zVKfx>-$z$Gu(^0e&G!gEYF1?fyg5_AvHE(KCn4#acp#A0!eH6JbyHY<&Cawccy;-8 zPFO|Z`z4-x2-&c;W=+3}>}r9{$86CL@Fr762p@K@bygk5Z$4%R_#-^mVnofN5b)^C z`9G^b?fAp8-!dG($gEV zLNNowLQy$edp_n-D=9D7@%%)ZKiqSRI!W=^OuqvgEXf556qGs2$q=)GF{Hp1U20=D z;A`DA9M8K>m4P8PEY$;$A{#SuzBuP}CZlA&?ovY+J=q_E}BSp>Ouw zIVJPX2(zJNTH?hNFQ0ti_IY`kq`aO!5r;5{`mC9aNl9-9y-`^&R60w_r|+oWeaKy% z0CSrEg%U-p#N#iHxjstj%xr6wUqW|sA!>c#h^$>$tUC`e$<1LmmBNB-arHq1B?G}< z{F-k5o_=wUq4Pa=Yu#1iHp;bcRjl>y(T|vzuOkoTWzU=m=^WXqB-%OJ=H0sX6WQb8 z_6eQvBo!k!N2J1yq#XlitV`%mITGqrD zu*F=ztCD`LF;7EVR8<_jpzE_+=`zfuLy&*k%4l11lkV^N7qi4H^s2F?=H$7EqE;F2 zF#~TEuSpm4Y_+MsKOq&C3(x^p{b%@}Z1W(|w=#R@a`%H(Z!*GEAv&V3>|K28gNS{j zFT?LfA_8*qSlLASf3-jYu!yTb3MRL<77J7rN`v&PbMQ6_G$w+2rT~aZVAuWVmU3Rc zuKeyDO2ZOMHX0K_wB&U+!P3$a^ybm>#nuajmS+8ZdNex(73>I*bXV$~?QJYjbqexT z1gmhbZ+N}*qCGu5fxD`u97HWZCm~`*CsXWg53uI7v#K-YslRei>-AKzAxlh;e7ELJ z17JselkKU!3wITDt!RzblC$}ztdU41tTkETkOx~IP_eTACB+4 zyhL;ug|%3zq~jCHBz4~7No@!bTq>AmQ49Nz2$n5Yi^aDzmD*aq?)s5gC@Hno`0+rc z)k_)2pvr8nR<$QT?;|)L!ub0ggNFY;A8j;!oR5MBZr6YAAH19>imJ6)XPahLiVNWC z{Z(Eq=F1VDu-oQc$ZcC0kx#ne!giF-J!U*pp-&Urjy;sG;n+u&&NDRJhn%X0pP##Lq-m$$~oVUU_F2h5`^i9JJnybL{or9xnMAScRXINv|P>W zO1@T%C?5Ozb*j>Hgae&kF*5?1atKb7N%xrfD`@&!{M`>WW8)5jyr0+a-x1tjANP;( zZum&v?O;F;65+3SRo%k!PaZ!}BExa9u%*m`aik=P{4a)ji3N1@-9r0mSeYmqys}1x#IoXfPcyqR%|N%>9XT zrJn0NkmaG^ALZb+p{uy3Qi6QryUrx!>(O{3P^9_!;_DxczL^V|;X~3(tz|{~43<&j zlj6quy=)WYp&x<99?e>%#rKCigf)Z7<)LLCjP#$G#bJLnXc@mm4)$`u6Cg?-5=D$t z#*2qHAmb1}7drff`$nbKO5#IlWEwcp5&3YJV=iiVKII3V!?(AL4j8}hC5(Mp*G2MK z*)+mh;ucnu@dEPNy!J`_QN*epnxAq@Y(3Nk=f;*xVYB%pk?0zw;myQYOGK=>2eD#a z%Z`g4`JDT?Y1pS<)MaSaFwf`@*Z=EK9Z?+UZl#M&hJN3YvHG_W=yu1+G@R|$?Kk>^ z>qU)IfrJF!6Sofw03V931U?ji#cY4Ay?ONbGt0XmXwxsKz1&}(6Il|WrG-E3{wo4S zoj~^&$zK2#Jx(j77jybQ`FzjdeFn7d{Q#1T2l0AFzzqa#V1}ACz+B!YG!9nDTClrS z>bDaB?`VLuFWB`lpqK*mGoi0JAZm{r8j^x);A^#6K=2Z;9>os87W&&<4>|lEJcJ(3 z`Xgdv2kXs7C#Ulh%6E`J0tZ}r@p+B^BO0i>N5;UA1efI`u$e-7Kx&7ImKJmeetlgq zdu^&A%jlBwWrPu3K{1ilnk%Dr`_d62CT>9_Rq_oXChvHLqd<+9MR?a{hAnq&JK!&Q zva6V8eoI_Ky%D92{da{+M^yWus4WHRR(qFd+GIx>79V9_E`#W(jN-p8F2B zFP+niyqivZN}MPKoT(1~%)>{&apZx=!SQkWDuJRp%tqC@v-dbbC3{ngH8lG_hyY*r zVyf83rM51ofvC87jJV>nKvoItTimA=W;})6lSjzgctQ@ZG+5aEQ8P7Cm{65d_v{WVT)G;ZAXLhb2ayJu{J z4ZX55DQ2F9&mS!|Pq{D8LVBwqOc1N?&U6U$=i(N%z2vhI|4q1dC`W28OyZKYP$2Zy^)ngfudeux7xAyzt5Pt z13MiQIMwd(;I!+!(nKdesIcsIqVRukbU9kYw(mGcT=%`vSJ%|c*Kc=Aef0qoBsxX_ zXPjvh03e0{#8v0=OEn%rI*%O%1qFKG&W;);`T!aW8bv{fMV{Js5cVwK(gSEc4uExa zJl!1v1hrp*cP#|O8x?5P7*O221yM9*(C^^s&xHSWUN}Y2ph<7OWVA%@`C(1B6mTCs zkR%>5L*w$8rx(e;{d6GvXU7{0yl+XsQs?3+$DEWS_kb1M&6diF%yD(l-J{^(Lk@(Z zm)qi3UCW3MACeZHi`8ec*1%Mz)0;!4}ULb(dg6z|XAO(QpjLhDpEJ)~R7s`8>F&{$W z-mD<{zN$(d;Om6VT6W8p@VYx7d*QewIY0-IfQR3FY7`ZS)bu|kjgItY#=EHlui`Bg zy9K7X9UTOrjF8-$2ffOQ{ul^K9VY9+quQMSv z=IImafJ~^XR@;1OO}_Zt*?_BCM>!euM2?DH3JygeAyNz^mt_c2Jt$n(a?wojSX=f5 zx&3Xl%CvU;M)Ysawoa#Y(Td+eXuj#-4EwIo8~0W#FqSqnPx0{K{iS%hT;TCe_+zt` zMOL^z%otg*nM=j>uG6Bbh|k$-DK#XCW6Dd}T%_d%ad z#VK1l1S3Y06cSR0vN{U?jYdg5N?g1cx$ueH8)?Paj;^_or=&CF)v3A+6JTyq?}nq5 z3_dtsFC=FL$dkKgwY9s)w)VWY@102>Mc^h%l7z#L|AzO5=UJj_{fYIbBNp-D`)=N_ zvI5Gh^D8UmopYXv+(Mub*sVKQ?y8$E?HHhLVo;Q;a&s*Q^<9EI6r4B;;i!|f<07%y zyYsIMIh%<}B91zKAlZ8xAQD$Cl>eJA$ZHMi2?7R8;izF-JG;-I>WR|?o-dRZguR#> zgQo_}>xNx1HC=b8e{Y9HeGGVKDgSFl5Y;439|Bx->HO~kNQbD~3kaX2J#0_rIvhUI zd^ew;jE|4j0$ItG*Q#tNHlhtH(}2e%eoEb-M#tt$KG;l3c}mNVDz`Uh=c&z}RY$F= z5Roj0#BJv}!Tld{Bj_AN9Mp7hG^#IapTsG;gkA9ntt#qTpHtGPWg#7O;P1~mcq%xM zJ*BOx@dSMPd@+>>Cnqnn6~mZ2D~g`1mk(k0V{R3?Iy$s0EU6jKGyOu(w|PA`s!!28 z|6aUGmO|7$uZg7Lj`++`1Iup2*-iRb4pqz;D4LY4_7?EHCU$m6{Xe4Qlwqr+eq6(- ziE}4V4z-vW`Tl(Mhh)aOL)p|H9h+F40RRWQQ$q;y+QO6(HJIB@n1(6SFPa;ES32YPWF2gZ39wtDsuKeY<(OUzoj4Twfa^e zu$M<<>LsES#rPwA>WqxnFo901pp3}Qm5@&I;Q)bG~(8Lg1eh3q%TNIVNrt$%BWdWU1k1qnb`{1C7uM*tS9sj=O84oYZHgJdlb zcdcygE=l+=%=H(G_OGMG3<0;nTsNjgzdErdU$7k*kXptgWv0&&>qh^syD*`8vS$|U z-&}amV?Z+SG=-bb7Da>XH>~VHQtW4 z4;jU0?rOu;Ep?`ZHmu)aHGr7{&wv zaNsUWz`>3T(zz0F&PJ8tbv-KA-l)^h!e_Us7%cSo1Ibo)PuL33>J!|IyTi3>m4GqW zk3mDx0B*1AIntWevQ09gbb`=c`slbN3AXV()i1^Bdf5r<*bEil-tTczKkt#XNa};} zuyewpuP#Nx*xQboi>?a- zI+d9m&JHVN$8h(*O>&OH2%A@V3vdC6ErgV0+v@EqC|cDNzXJZAO8((f{W`IJNSYHd z>~*ctX$S9UzgI;4U^ZIAr==T9Fdvm&XREPt*SJZhQ~9tr7xKdeMC9km9#8QL5HPT7 z8ToE`{?OE-VlI!h>DDhI1U&|L5Dq?0;v2WGDz?FSmO_(_ga%99aAX0;bdu>4l*;k* zM+-b(ce@~GeF0mymjlK}FFKJSx3=%-w(rU^Z$_`$$akzPA}7}ehRZ>ixJgiu27SE~loxomK8A_TYNne>DA8N3PuQYz3ZcV?U!qNA*xZOmwPpU@7{3;HsI(h+h zBsa8iJGG6ee<_H_i9))6T*9$B2DL}&b&~bvPWalY5WL?g@r4T~eCcI;kb3gMyH)C! zF9Vj?tJIkt0hS3truNf8#8B{iGZk`T5pDO5E?(Rp`#=}CiB`BR{9nQmtcPAeuIsd; zX|YZS-;e>`t9AY)%ntf0XE*;4P6Of6{m!T7h62GNB#YveR+9L{dZH|l+c)`Tzh^~7 zFJR8I#Id}bY$Jz1P*=bew^`TgH*zJJsJ64#LUt}*KH61V;`01A1gfSMxdy3(BrWL} z>Gm_^cRn<4W{47I*JWj#@t#8V(%0xQyE6sh66d0tpB?oj`XbQ;eC*Z$7j%Aa0f+V2 zP1K_@&P($9A0w%=@?X%Rdj{`xD*2qrPGizy9KtJY!-_ct=x}o=1ap_FOM*?reTUO3 zK5w_bVt_#2`gO&eBok;W>L=MUBga7;Oxxuv_g8V7t8@?(-wfCe^VF3~=I|0`4y$Bn zgWsOweG~ZlfIN%3KftfGf3!@KzjWHgb#cX<%jXX523yIn67L2Iq<%qPDeF7hPpQn? zgV*S6oWUUBkem|5v^!3wGxZtEm7+ro8s_WPU}Ww>{c>r}+ziz5t*s5e-{qn@rIGGI zxHka7Jb&A6CEg%v>d}}AUiFq+A4h5ul{gZlLSeB6g?km|*hNqlSnnlf(tdn%E#}#@ zFSLAc4>|OXdG7c9WF6_E$sPh>nGlXF@xoAE^_%-Mv%Q1zwIlH(ptxIg?S;Cov>-j( zJ4ojfy#;3Xo!;-eZ*R9w7A&t(Aj|Lw1-kM=ahjiv?cZ^xvl74)Cv``aMXCaLrw!PJ z67WdE#%)a`d+T0ANZF_)Z#U^_2qPw?PIk?gTL;C`=uo>n?>9PUbPX>n>F{^yW|Vk^ zhfwi`R#w%5h`wxg!zI*!unb2>yrdbxW6(SLE-r50k<0r0=tQA~DmLEnI3!x(%xMBGVC3_sUec{Fpm=ZasxKVP-WSWSKvzfrHf}J0-UswD}^ZI!}hpBORK4t ziuS|r0b@tOM$=!842?C7Mq;g!P2Y9&idL?s*o4HVjA@ZAy+XP}Y8_VRTaiAK2}W$+ z39yK|qPsIn3}lxoj~{5_zSQrX54-10t0K2D!+_L@TZ%wOm+(}h`PA2kjg1wdrh(6=uMr3! zDUQZqJ-#1&m=kuo>`#x!vuP2i=;E@z5xZZM6%Ib#EMGWzb+5VldO$GgmL<2`8 zxbcY5UzuLsj>^txOypHqC_q{)NIP78*1A7-NPlUW)s7Dc{C$(ap6R{!=D zpk$*R{JR+Oh?!MN|7~<)?QsN`$o#kZ)P}7O+BS9xX-jLh?S(RK6f0q2`Cnf9>y_YG zZLft~#v#8#jJRPB1Y7BY^#I@Zd*7_#Jy`{dYfE%{N0@P~BpPA*d9&}^k>yS*Hb=yy z@zUOx8ZjU6285ZaGiBqimE9SwFm5y65tOiLEgPRr+Mn}tm340Os1Db2l;tVez}9#95XKTj zZ;NEl*Uw&b-9N1|vFr=Nqbrzd?q%ODJdu}L*!A7@UDVcBj$RioaPVO&*)%R7dHVm5 zke46pyX$yWS`zrZ#bfR?3ce3tQiowk{y*3psC_11t+JGbvUnRZrtY*2rrn~>dOkjh zQTN|JK7n7IuMWN6=tTck{afuUcap@0ddIeUGH;tHm;RL!DmrR8M}J{g*qjvuq!&^Y zHD3v|emVKJMD@-xKfv_hE@1S1GDU)c)SRw>2X90_gYL@JV{elUBG$qQF;q;CLu`;$ z8=37F*7x4p)z^hU!1caK%jf6eQ9~puP~mOAI>$X+6mWndi?db8Qib7gLS6YGs<)1X zQI}LToQ9kbC2@{gV83l>XeSMZhQs97FeIRa_iq&GOTj*|6~pc;!r!%x`>>E)Lz9=q z=B@-2j^I_$a&mD4fkZdi55ydua0uY-9+T(%sQ>wi0K*(tQS^+@4UcfMv2#U|vorQv z4@)k)&V_hYz2F4t)HmG^hETM(@O{L>aNX~h9rfO5=`7|C4%hG}C-ZMz4U2BI16KrY zdbARYgLnBVMI7OaVlnnp6C$Er8(S{+@F}h|#EyMf1%1`=NK4FJhrS=#)gR{MS&ZF% zalXgs{XH|B8S>L87=}h2_46|Y#W(G>4l44R)O6I-$3zBNoY5OHCo|7n4}*2pSG(j& z`4i#0gk165Ke$Wh27BvH1au3Ei_0q`OEoI>*C6aUw6&oE?c8KPm5%Eh<1R)$dD&CO z_62b2S_In#pPW6|55~-WoPY8?tyb|u*7X@Ko!xFk6i-;V!)S{(*9tmbIaxaf=5Lm2 zB~6#lF!LR4(>kW@QgOBPp@s^3fT>~Cp&B7)$fZB0LMe)~AwW-EZ`M9lO<$G{`)++) z_`GkuY7_F1nG@|T*q*IdsNV1?&vGnv^l!Ue#COgNES#vX6@TtB{_=MW7j5n@{tRe3 zc4*0;H8X_S+?nDH$g}9^ut##!djUJ@7=fbM9pA3ju8v}FRU#MqT6&f`oE$-eyKutV zLKU?s1ut@{H*q>yzgU;fq-FYEtjAHcAE@v&>QuxITg&tId>Mqzx1+7&}|l1{pV@R?0U$jFF{gp}05iD2NSQemPLKlA_M$5{BDv=mLqz76o; z0T5OA5fJf!v8Sb_-2{8U5I}OWnt(?7FJ@G}`rqBXJ?5qckUoS?DMAL6kP#6PJwTOG zXWWlC=aRx8myK06;4VpkZo~sPFjAA zn=Jx#GAa>2$31!nvhW((@v3&bObgWzfj~>l{P*BAiY(doV(jVkXnsQqs-)BCI5ag( zv?436IXt|KQ5s-ef&2=bHUaBGN|nFmf;ra;;>yo!(jX>sft@BcVYB==2A8PFezA;E zBi$*4A94pjxYNcw%y`+QNGwRu5aW6xlF)O$9Lj)s?%eU+OZErgv8;D5_8Cf8LRjBFwvTp*J3B05wqeF-_d^QS0MMK5~^vF!}Bk? z7wvdrh#!=KjMP*O#cjjJ$b9WF^+k*tU0AU)WZF{rQoN};EB=-%XFl&0HC!;>P1=He zgzjmqVEeLUIJ**=>Tx_4Wd(h`;@5A8k}3~IE+*cqnQ5`uOvLqv!(iX&Z%4X`i2uuJ z6EkzgzHUlgkQvu%vTS?15OF@@dnWtfcgW%+Y$;%FelazvK!9ej(X3ymke;tjE)I*% zsLsP}qSnjrZp= zhgx#kf0NKZoL@b$@#m-BZq>mV@ng=hj@=$yazdbkjh%Dzs1PWs*kuF{(CqU6?s4@(R{P#JJNtDD(ocdyDZ5q z$wje$IDeX<&qGOdXRXI9#;xTvx{r|6r|KQKoD_ATq&A`=aFDB-M&?>J zvEIMU9IT$ae{#@Tz53Cj`r%?3_G-eux3$U*!GdN z>lOQ*oWBnYAnqC(=`{fOVjw2H1>&l6;8EBHfaMTi6|=6}L3dBW=g&XC2)%gm&2lNl z&@U1yV0LtLgeG3YKtQ$~YjyXN7mIEaIuP$g+uQ^7yb*{f1rqi?5Hp&PloSF8CSssh z=TOY#3j|?!D}+jq!?u!y1F2t`(w6K{H2RlIOtqazn(%f_m1MCE)*3KQ zDVNa#YI@xp;$9n4t#j>mYN@O$M=$1L{SdjOweq_|Vkw2!rJ_VPiDtj>wjL@eQ+qWN|ofOEEd3P61&!M5>;>eJ}XP;9^0+pkv)4n-pr$%xHhj zF#Fzk09?4_3h(AMn=qCMde4MsiKN!irW+Gc#^>GH@V)Sup=TcygW$sV+#llFfpd9M z)B3o--&-Q1Jf}|RYM#LG5pJ0$>*kG(Q`o?zvqxxe796*oi;2Z zeUr9?7S!8H7PNp4U=MAd4Bi=e2FRXa5I(M$pyy zprIMN_Ai-;;n!AQ0d~1~Q=eDfjLNjC6bz-hm1rO^DMIpyEqX8LsYRK*O zEfraVB%@%P6$Ty#+VHlIE%&F5`xF~}D?kbfe~S0i5RA8)3MP#jFZk|;p`>scqHG#4 zM65PEm6k++9lEvQ8z<0!LP;SSaeAScC%&C+g}a|+kv|YJD<47xMr)lQVe>7xY>NtM)@2KE!( zi3Km;`!fR{w@17+Tb_4~Yt8}ASwq;zs?dc(M-5FgM>N_WlcdOlk~lRX=EtN19(SqI z^GFO-sfgYP*hs$Ja<}s2U5?|FUAFrwH`}Vj2(xGIr!K-JKz85 z)a76;povtX(@P#9izE1-MWD#=uGPf7b%x!3P!~biGfFnyE+FlOcA+$Um>2?#8`i!u z^U=EZ>0ksljUL6{-@Btw8pOoZ)N!*1hL+(1j@kcyWn;2Jiyf+?1T}bXJX(Bn{egL< zw5+Tb=tF~0aOois1CK2PXnOEz#|0)-ZUF6$sOz*w1k7$2rNp0g&5UOO`_#_i2*FZD zNA(jBi2quLqMlu@r40*T$=h2~3zD_%ZRkV%?I(TeJLY|&jKu0D6uoF-^<3=Cuhggs zE~=y$bT?P#|~(T`czyNbRP4P9JT;uq13Zz(6RMTN|bjYS=IdXJoZ_HWVh zAM5%Rs~m@hG6<0Qvu`J?rMwwu%j>otmW$eszLKVJ*$=+4w>M#F4Mw}#=VT?462?6# zl=)xk*1tM>1Caccrr%-xGr6hYw{(u>CoMO6A1SaydBbA|A4zX-ltbUBn`%TK?o0Iu zQ&G(l1pt=N=-mlQ64$(`w=8y8<4(@uP`=OXO8qVQj(X8UnXY}Mz{G;JkMSwbQrj1Y z8`Eduw8fi7o=o3aqSuNBiFVJj$x{<5b$_Uo=P^nCbU@DwBtTO(K>Fgk-z))o?$O1o z$`Bxv^Jvl(5c#qFh=OFZS|>8ioQXkU%p1&}`}AbFg|i=DFlErW>`b|4^=0JxzDr_?2 zbD_21f0CSDvjd044xe_FCP$2yCo^o|^I?;^fE~WcD!#O<^fW2&83kVc$XE!}6&6Fr zkCVb~I0Ulug+-pgQ>&t{AItCjGXiK(L8-(JP%%qNN+S5(ZXr4T0rYyU*(d?{992A0 znCH)*LqGv6EW`+;B2wJ!#Rs_%u$US3I8`$8g(1P4e4&k!V{c^rW+GNF7^Ca=Af-vE z8tt;uLxY4#XjT4tc$r+&2LDvZh+QAi=Db-hGfE%~LpX&5Dl^}r&)ia^e2 z(d}EzqW?TY2y>=lW7DdD9Jhxw#x93=@re$$$N7y^zma#Wh=>=&z0x2~Z{={2m4|{x zZB8<;`xynoG6BKeb)bHW@0F9#<>3pJ#dlmDwAIb6R)y#FG^*2UaqQ$Yp6oX2ksA0R zT=x0C_6Ljjt_Fh*|Hk-Q>n$mSA4N#+b z;8LQsP}<&o4Iv%0HCSy$Va2@aEl(#oAConBE`yBGA1h!zabRrMO@-XRzxveRPxy`Q zq_f-3AY3kE=4?Xv-o)abYl+{E5*vYL=TTVZfT&}2HN?b~fpD}aD=IZTe%Kc>8D1fztk z2Hu8b*JJs4SC6#D70hmFvvKMjxcZiHa_f^+cLuiA-A^znOJV2SX65A*i9 z9_RZYn*Si?XQCjX8>nOl2M56*2zw|gO9Q**5A)d!8mO61Et1B{@_lK?=QU7}z(a@x-706=KyK$2h@BM#o!2oS({p9_ z8ni>cwycaSa?Wt-V5O;h#qLorV82nc6ad6XEe#j3|dh9|}YY=EB z%S?cas7n6 zaYVBl7br85rm(3fhRE8s*X=q7MkYNo=D*Hg$;$kGn-NDFDw*3mH-)V)|M%bz4@u%K02IwFxfed*Xk+|+yO8QJ zxSKtY5+^&>=ktIFxjh!AusfVqGgdN?Iiy%1kpA}8Y;-O1+RTE4M(ve@cS^R$UJu)B zhQgu))ic~z$&$mW675?hTBh@-nGnWa~+Gg8NT!!(k*|wym1{g1m7U<_pD+Z^nIG0<3*8Fvy%llNNQkLfTkl6#*icxPI`xpiR$= z{(d0agPlC?or>*bqXJo+RKk{!W^nh0QdaAvT)IX-<%^&-onB{EYTLWR?v*-0pr7~5 z)DRl#DfMwe4B(<;W3O(t*PX`D;t&uJw1I;G2NxHb%3N{H1lH=otHm`fV#lu7x z&)zMV6>xEI6GI*ivDLGk;6R_AnQ`16Kmn!-Nh)t)5Aiz27tznO1$(dj&r@-jq~1s8Zh^ln2V?Yut~q4F0eR=LsY%5r-_Dczeml*% z#y?!+iffM_y6NjHoxgs~hj6CyG3}h=Nc7p^A60d|4tTwCg63C}@OdEmeecC_T$WKZ zLUM94K$R&B|29~Wk_9l#|Ba!VI#&!XuD(VHDB*buGUl|!N74%)rrn&GO00FbA@C%x zu0sDsbafi1c&+t5n_FZ)Co%l3^pXRWOu z2iNhuD}S5q%k?_pr=R<;FXL68ICR(?G2-PN{`yu7#WKA!8T&SP?`~qek+1y(<(U*+AqGEZUGXR#7D%IjA;zJKZu+~Qq-g#s#~=bNL4@x7G1ZAutcFg6fH zX88(2An20DUUHu&HvZn2yY*UoHx(>d^-ilv*Ig#9e*A0FZEQqJqO$b721JglMlRSg zgoow5b!hRx>`vfQw3m@T-(Rm8Y(>4gk51wVd|mpfXJea1hGdh@_~3fk)ra}uTs6v!|&KW`cR zq0_5mB-F2lZEZ}EtgADX;G+2TDHz}@m(dE4H#2rw!Kpn)TsjB0G8&MvRzl;lk7x|9 zlU;KmwI8YYKC}-*#Nc-gbmd}{>{F96a4AF;yC*dq~KMHrw;mvQ)V;8{7j9j(LL-gq{V$V?4a_d%yCz4Ruqj zzF4W*w6R6?RCd9iWd|?Rm$#FU1_Txq=AtiKMH=jx$LHqy26U!g@9=BCBe-a`pN*Lo z)>Da!#Vt1VHzWR*PNLs+N)7W~x~aD;{e=xPlZ1*Se;?iH9{A&15aNVI#~x5OUy|5a zSU`pyYqc}2Bk%6~Dv`9hHIV*Mo38Ht)o%nTZ zQ8;N&l5>%$NVJdLlO2RTFWTQi0e?QA=hER;y-I(&vf#;ZT+nyVD=I?x`>-c!#v2HN z2eyGus26O138|?OKs0Q8bub6*XFBP;F#z^YW8h%fzV3VWFe0<3qX2H1dT zfNF6pa`LK{G48;^ain48Tc#XB6FB0bGV^2#IoG!5mo_OcBGBsO5Ux$1;9cXI{@!vucJY0@xH|5>m7{XksX^p5%Fz5DPOCfv;l z2S;8O<3T0sNN_TF)nfAo>O^#WjeX9Eqj#LV{HXXy9EQag?ON(nJOVvl3746ynu^!K zQHz*oip_iX&n|v(r-FBquAob=)`z%)$};SnY@s0OwwvDKXhMhepSfA7>5&qlcj@c( zorPFqntQ!wFmOc=LRa$(AT|=_nxWl*A$xdrd6Z9Y5H09&8z%vWAg?}UPsSS6>ioJ# zC70=&WzU9Mak2SULqmth(@v;Q7~s1=n|K1qT93O}IpOIU&JK%VFN5?8b$` z!6&-15BtGO!W`V(qg6X^pB6Qr%Ioq*AYJ$gd&C?{FlF+NJZbcUvh z9_!&-6zOifEO#8TnrO>tNP5`UV9Er%00e$G^_p0>yl-s#^szS?VonxBcZ2jTb2%vB zB^a#GbKiOlx)_H56GDkqxsmxhn@;tR$m8P^ z&#>ZihdnZ(2LFuhuIg%I*u2J%{A5*H6JnJA5+a3IlcpDA6BCySG_9w%klZj}6$1OH zmBDCG85KSKPdT?yd%cDrn$dbFvRMcAC?OEQ-EvqBB8Gr~RBCC%($dlc)Kfys3H1CQ zue5880_7cSphcHuEqV?jO+%skiW#wifmmU#Q9;C+4XwDtMebr&;`uAK?lzOnO2Z)% zDUV0wR+N`)mX_iYI?~9W9A#cbPo@&hYBY-xxmC$TKy2ve*=+SGH*^7bK11h`Wl^3u zDE>6H48S*HPDe2|N-Eu~EM(&m-b+>TQV2~5MOYC%H6J~q8S6uUrHMA*woTLd&@{Ea zuSFqKiJwP94UO8)whOtn$?&+g8Sd_lhE}eBSVK?kG*5y;VBe!J;`5=m=iXMsY`2&D zFzQ^qTk6ijVf@HZ^c4u*D+l&P5)vj3n$kIvQUj3G`8>2P5l{F*j4Qr%?Px$3)YfVW zU5`4c_yF2nJa*t$F+Szt@%`_?k2`fNH+bJG@&EGI6Z)={Y$v~IAH0gcp**;_>4HA> zt-r;-^f=!)Z&j~9W+E0cNJ`VVu6QZzxdS@pN!Ul|f3WqHaZz<&xI+j?BVAGkf~1sy zfTRdgGDDYimxPocDIL-vT|;*Zf=DAELn9#F-F5fyzVH9u5BGfVC&rmG`|Q2qS!+Gd zvbAo=CDqE77e`}|yZbSUl=A7?mzfz?m-w$ zBnBMw02tVjut!0)iiv!O>TBl-d;B;<%5|Hv{QOS~Wx2ggYAM$m?e1($AOvz@$--4! zsguC1QGEL6SeJkwW3Hy;1<|^cpYuC)bl}-Cx?g2{f5p&Lh0hTE&7nGsa|$CWY%eN* zVH2%qt$a^c53`wxX(;E-Q-Db1q;JpX%3#aR+*i`Fw@Ol)qdy7Ngg$_1(-`I9-F8)5b%pZ|2s`d1uDJhg5 zd!Va7qh{>#Jfe5q?-iZ=-=se|fy@He&IN=KC$vR%I-VcO|2pJtF`r-%Ua_8CrPWd2jU1Tf|8H^)SwK))NlJO$Tk+JSKgqwe z<(dj9i}2^;jpxh-=t-plv=YcMqo;pmPlrlLj2~D%rllC1)Ey7ZWcCZ(D13V?Bc-`w zXdI@~T3=vD9udyZ^Bk{$-Qo&N_!#8lpBR0~#$Tvk?@c${RJdOf9dWlZPW5x|0Ot4% zVPf=o#C2x2u3lBrpRJN(XIk0lk!4o7vc_9dVz%0r($&*fpePbjW5#cOC2hM7InNA8 zZ`4jRA*TREv0mBkM(xit4=*pw)Gt&YTA<2wSSX|1NY`8>B>8gh21+IXD-gS{5{b!o!^lCy?0Su8F?`m(_jI06W282 z+1?q#gdnqv&$FIBz}7>meRAzJ{$B2rf__CODJfgF9v4ma5ge#H)Xy!W*2qO-+gI^# zUiY)TG@q<8qdtgRYL0f^-ivE8|xiKFfiO-|!WiSc$g*3!)lN8+*iI`WaQZ%pb~w;~>Tlgmdpj>sP&y=@bmNf5Q)P zKHx2WA~04i2W|*h3+B#HnC$eI(U6s>r}v%EkOezT6OnX{LS!3ZJh2&sUPQhAO|uhs z4mu_7JK7&0v7#wjTrhtFV~tA@M)i(xY|e_QK>Ib;x2fNv=H-6u_xs$b<~)o?${>`0 zb`!#>`gd~X9r69*r>vv4@^}%F!4dbEz$jL_3ymnRxv4Bsk7_jgB~-Co(IfLzXfh1CqaW67 z)lBV9n5NaIYOsPW0^Mdth*}3;cn&WgnI~o@>|ISPR4}$fb{+pv?(N}GexoCyjfwjld_>PNUVZnsmqxfna-F;~ z58p&4-PQNS!*F#_sLb~3^%u7FyRUZn{u&qdRdk`M>HE8%NQw(Z>o>pz*7J5 z(yWjb&=P%}Jf>c^+Ut%uti$iLxv7NbcMb9&n=0ap20-=n?|zEjFlGBS=qD5@I|9P#acT!?1;%u}>K7PqGMcKJ^a2x8AirS9KsdB7afo>Fxg)b=;kV((c6wHDBcf*6!zLnsl?JhPWH+pOaeb-ZZ_qQ<0JFi{j zT+E*h>_qUG4~OPjD@4a#IO^!;uB1GeePk{-kh=1LhSp}v@)fa;vAoyB8ybaDtE)%B zj9wzeI1Iu@4#LxuT%=yQ*aXvWiO{v#x;lqKd19_U!up~PV6OU+Xj&)h#f^5OcKh`# z)v%CC>Ee@=msyB+1AkIsCm|5s@oz)fZaF11MEjy6@E7dzi5gk22frzoyViZs3b}X< zzdKJef1O)wp(|h3C{d*IzfoUS^mP7hO!%UnP|!Id5mY z2*7T1y%B4gY_<)PkkZ;LRtY&iavb+WZA~rT+l9J9Abh#n0nEYb39fx*A1joNYe$5a z9E3--$-S53;|+et$_038mFVzY%1_Vy*lU?J{-p03mPbZD|Lz)ZlN-L=%rv=wxiCp( zaOh~Zfrg%?rGDLvIa1u_rrt4YE2fr%3&G@8!J9WdXw?t3T}bJHF$VtrWp6b1N)mbT z@g%e=&q_N1+lL$kEdxVNg$9NO1`Gpf{Gvq4nwr6Y#Sl4944U)#3PfmLxbXnO3$RV~ z=pW`(@S44TjlKLUEuyT9Gp0VWSgTeV%zBo4UN{5lprNwnrY0KXWEDhzf3e%?`rFKm zCX~-DjG%L*(HD0r+$0xKFP_i$c#O@QOkJbEwLY6rJe?g2VL~qUl+D*w!sQ<9KCJ~7 zO*09$^>0kD8BW zJFXgS(p1+!*o$PO-!dhs+k~4vufLprB7t@2x)F|VZb&n)lhZ~72}PCpjV<$0!rw4E z!tiFa8gJCoP7;Ucne$Z`LB#g6&5NoVrBh~_3n6K^Ma{Y^t1xHPUo@(MhaIj6FIAgM%x5K{(Tg&?y)z zs%hmb5#fQ25fI{rnxiTvZZ~)47LqdrGyV~b)E~cT+4h`gD4f7#_&n{OKG^ImRef;% zM*4PB94Y#)t(+knKmd7(Us@WSX0?Sspko@_OaNna$a&(f?(Ua>M_@2ZhM>xNiY$zZ zS6o^e1Ifhmx-QC6uv%PKzdeAlRgSaYfmg|R7{L-EUQ{lzv1P2 zVd-G4tsb(91>pRbs;VT{mj{Xu9zA*ow4GOf=S93H2%3zf^q!*X8P(~psrp~(DF@!*zGPnJ52>)V! zmoU(rU$ymeILm--?TLldjK@ZYR%MEV{g|aE-=s4yty7{!?Yodf`E7?yjs5SDh%W@v z%VM&D(pp#JWLs{w`DeL!{z-h6cGmZum0xo+UM?t*>n5WSh3fi+!U2nW+lV968Qqcu z131B%OtFSdIjSN&vgHxOG%Y?pMoRym(WJG99jmLBp1CO^^iE&4iH4MVmY9E=m9ZXC zRb>X{tSF#BgpD0SZ8RD`bV3th0ej+ zD2QKOoGcuT>2-qkqtAewv&sU2t><<2FjpaUeQ;-XW|cKwlGUH`yvC29%4}o!jJN+2 z3o8xB(vmpYIzfNOdoxFD+yiQq^n7JV1*-`MFMmV$GXbGphoh_Y4_!+^yGK&e_CzJ* zWa2}kO?Zh zZ7tupM>vvj_@lmH$|?sNY?o>x*0O9GMQ!@PI0`fJ(#vHi4zv0&wOjS~7ag<1S zL3T+=4qJL*ZgZY2PT)=0P;I-X=PbT9@44OUl#h@<&7SE-2jVMf0!H?;wRz#`imIUx zPAtRdu_%65&@R(VqE+X2*pmOsa8k-Qa9*CV@26FfwUiBM)t5XwWhlEnL? zUI^NJkzCVuLwk+sZDV7CXx$50z{HDB+nOf6q7pi8CkgXpz2SSxzrM`HQ$lHU@5N6^ zh3|3jTXSR-h(RpCuG>Tbx)T!cZ#E2C2t%xYTYu(K zq+OA%k#6t5fY&9?+Qt%#vF_T- z6bH^2lo6^nuRe=wK1?ISt&%Zp5g8ect4BbPt?CiMq*di#nrCPNr*zLZ7P2bJH@keX z#*=$C8%{gMjs_9*lVU7Mic71Whui%k^K^f!)zXp;-ppp?td4}hlT|ItN%-nW?dVgM z_qwfkt1PXDj$0<%@RW;7$z%om#t6HtICvrgJ)H7SY(`KP3yqG>8~2aDVRhi0Y`#%^ z>Gs(WK2cQeqMJ`L4`<6_MU{b^K3wB*lai)B#*5wEX`;tkC8rJ4V^87p3-f;OW0*Xr zR0WQwiWhrLp7zU^A1G0UVi0NSa^;etRMiH>PwUdH_LVra-H2!MmV$xh?ApAu1Ehz=yw)#DRjNme!u!mi z{w$;rF?Osj7yS_Bj|7UY-A*l{m8ed1>Sg09PzIUYbUrKR5nLIU{M=&21MwWf6}li1 z<%YB+iA3|$y`t@%;COi(ZaGzK6f%(q&p=;WQHzy`Vsi=l%hszmTBwkoYO0v$v9ll- zeIvDcCf~WOB(Te?Z#%fzm-s^xKd01gWS3X|b~}#BI@B10c!}>u3y>C!mg+?SG2vJC zCyN-<C%X*;;Cy*C`}0kD5f&!sAoI`3i$EFl{3x-Y^~xcQ=$ltvmOU|~)u)QY z!@a0tD}@9Yp0Pmp(bcKspO_elRuKfa@Pkbam%jrHDI{Rw7T~HcJI^rZIvb_yviHrd zmJSje&1xr{!pUU-*dD6G!T!Uqh4zGspD$Bm=~u<|@j;Kx zJ6WkNs_ery!{Yv_sxkjYjDAV{w_NSYH!i+NqaM_CPgWBSx;%%rH8_zB5TEYg3|1J+ zJ&6-%|JzgJ$BW?Mr072TP^TG5;AFJatnHVIv+x0{VrX^ zv%l<(mI{V)7O$!RpPw@vGPgDF17pY#NR0IG@sn?H`+4_h2n`@0>H5HyoC5-z!M;yifG8|f4FqyKo7)jVY5tJ2oQLB#MBBAmD zCrq2|X=|KiymBVYU?_2c8MYxlDr+s&{;R#bxO?Wg*&_#ZpJ)I0xi0)@^zcfP{t8L~ zi)u|HQ&Yp8Kh3XRy=n#Ya`uO-%3#aADyMNk@mw0tpN*HO+EuZGE=51=N0lg}jn5#J z(mqk6mv70++=?Syj*uh$ULa4{2mks?@aL9tDLGlX#j~Mgxtc8f&by1#I|j}kp2t&X zk|}*XoKw2lES>qhy$=g3e%EUAt@ObXq{MSO$1i=s?f4hNn+u}8gfv=Vn zxY(#FWkAA*PvRul_mv8dPr06<-{5!ud=t?_tX@*04TozemCZX;)T&bo zz#?V*=iceEL+D3E3BQtwQ$`2a%OO|;2hulBnoT;hmzPkmFb6qNWx`+~iBo){$@Vf5| ztrBrU_v|t+JYNDaZ^n=Xt#g`Wt)Quzq(61FO#gm~fPG5z+vhDr`3dv+I3JBZdE0Sx z_SuazwwSZhP(GvF*r5))>4CP;MeGNPEyuD2XB z^pTx2<0(|$ajW9^5zHLmljTVH%TAuSm+gMcGAqBf%@-Y(v!v`enCj}_!lrUeWDGP& zGJJkBI3LRvWKocYIJmS=R!JKdCWLT#iA`3(p-}brx)$k;=N=8WoZR5kr_;v8!N@I` zx&8~fW$2@o7aVV+>`-AD|%<;V$#!K?P58%Aa@f#vLTP4M>lJ8Jet2S%o zU!a%5_xs;FBVJn-X$@9`ahR$W=cJ;VrO7jU(rG*ZTMnVX!SnWp=J3(d;F@62?-qCru zb6SZR#|Hek+^u#}Uc12wi3WGYR>*d|6hS*YV0~`ytrvmi>mR;ZM?VE+Xh-yHU|Az`XTP6W(DRvbPRCDLn-DzF>Y>kWMTa79CFZF_f6^aZVes z;n?V@k8(R_`J>D>`8Rx3ARugnQ?<{I=UVB}?PDrK**TsAFp3(`op8t}a(5u)7@p+w zG8cCF3eWJ^k)WgWe-gK2WTlPYw*~n*f%Q3F`?b-p9{y2a&r8ccLP|&$8D`Rbdq-i|%GQ!gJ7fk}Y0PsK1l4<55R2}S*N2?Z@{C@rnSy{Pa?7H}j@1u$5h z$Y1x+>-KS4{W?D3)DAcDH}w6BeDb1mN9cumW`Kg;JXSH5a-n}Ked?_TdK&ovNp}6Y z|4Or%4Deu#H@a|)7HeVZ>FG6{{tZ!i>aQ(7VwMrT#!92})m97fl$qjj#jyQ?$rEhG zPxCSA^*hn(CUt$SLuZGi;l>UqA9T6VXL+n!ifpUbpL7v!udEy%$0lq3OrM)dER3y)}(M^Z<_BO*G16#HP2 z#xvkXOLA4Rru$iGP|`K@jU!zjPUBb~>0!0zsLaR~f=!Ih+?>) zu)TjU*ND0%2XmdpArlQ1e%1UoO&H*1_sxdL8dajBV5&jPq8|@KKF4qVD65oEV00*f zC_V%!=NC)fE)}Y)lBO@_bqN{ymt81Ya6b6~U?Vz69huffkAZ&1LSgB*HUJqYX={h( z=l{c04K%X>0rkfE((~{@`?-7+;h7WP>|6W4C>Y)JHi6eNuAlgZI}z8%$7l5)T33R4 zYD{HS{!C3@hYwfzz)mH&ekT;LkeXA0kIfg`67ul`;x0SDo8Fle%jHjnaZ zwh?vEO#V}MfH-@Mfqu$?Na*}#WK2+s3?aIo`^B7_Jwt^NnTg#Qb5ub=(W}oK1v!cZSbzb)VO!xy&h!`3+nIFdwCoo5S|jV0&K__4i0^n z@$Tjf1|(8L+i!oSfp~@M+J{|3Ph#QrDId4VkPDbt^!B?1oHA5qkMBNYr6~|`G`>)8 z;PtpZTRGhZ&}`qcUvq#^mj41$tXb?l)boIm|9nDO>vRzNcDE2cz@`aXT3i$Z-`KBA za5FA@lmb;KhZ-~d&6O{1d?|)%$+PEW`wo3Rg@lhS8M;1H^dyF|(!4F~GitCobdo5t zd%)c8+9kzM#A=YWVe1X8aZ4wx0^FVA`i9`qP-3*THGobI+06gH>+X?E9|_F5l#0;VxAfNcub zD5Kp!ux{xbvQQ1^<*#t0XO=36mpoE+Dt(xlG`G<$0y?{y+1Ue__gyO4!0$_RQF-Y- z0FmkHM;R^Z#<}1Y7Z6ldlst`A74W*;>m-yS!XamueEAXu8ykBU(632k4b1^{lIXwe z-EU$oSw&Gc21vszg2Z^93>F_Hwa`%MNLvi6@O~3R(?DNA7ffzZx%*M^A%q@b;-4i7 z!g;|3?9NSTuUw%pO}O;wsW>6JDSKE@{-3VCKJ;|=U(^&7L1>uxNccenEzC*eDybYjDuI(Fp&PR8sPFt?A(Cdv16}vI6XgWiq8S}&7CSj(6 z($B!LtmrlhF!T=W{`RR!e0pSc(p}Q+_K&4MW~RHJYzLGnmr4ZUn~fGx0^}s9sY#f; zNP+58=Duf2{^+DY^sO(nF^}hF zzOIFOS3&Ov^b#R0NCg})146ukvQEgw(UhE=9NTtjJV3<%jmH=*T?4*8fa;kDpo&$d z6Sgn)-Sv!mh=WnGi3+K*nCXM|GOEJB-&rX8FnEvDsv;{=a_yo6qn2@C)cA68fq z=R~e;-eif6Q02m%mtybUWvccvkP}vGexXJTx;591y+_le#M;wOM88pHeB*Xr4b*ODZ?g1m5{M-3DD-eaie48teq^}vw zeM1e-7zqgIBw`-D0pyBwb#9--mkhM$yD0z3(Toxw(CLg>R9BxzP?0GYzpNi}f9qmf zz=DOEbZ7ljhuj&z9T&$~S%goPC1tHV;+OTvP0P9$BVz?S&3%aX(=-;JtWxDU^-P}x zoz)R1A|k|NV5>VbBm$8vW;7xm6iBYo=;-KOFnWp!5MHq?=`ihP$-sM|e>%(CqHecs zk-31UZr4>luTnRtzs=k%je;O8&yhO{suH@8xCSgNy@a+{q$Rm?NTDDQ%2UBUz?*tU z0c^pC+%;Z&q3|~jDth;nM^#PC02``kr9HVj0CBtb~Ek@XNjFE*S|Za zS0E)YJ8L){9E|rTAOSHmQSEkpz=Q`Ge85R~|44rePveX2(q~&wR>X!>EP~&v;I*Vt z+HMzT`;t{w4ID&yzNie-qkRP%VEggsa=LM38~j^;i<5;p#Kpx00LiZ(ot+Yoey9GJ z1mZxTW6a#y3YZGccR#QS06M}5=T)MVnmo#EfOZWsGm~OWP!`2eW`rL|I*0+-yDNM7WGSN zY8VKAkwl(gE-=R6BAixe7m&my((df+L;%2*@#)h(D>{VXQB#|8=?A5J9GEy%>f#PP zh)DKvK@5Og)L{X;(7Wi|ju2Iha3U0xNW8Vx=xTr;CUY0qjWpQM{uAAeKDbL04`@OG z(vF_8936Rh|ZFOOBYObGzBXi!gl*-(~^K5hlq>p*6yyzXOP0@ouw4gfGy^I7T7jSuGYFvlZGv1KsA zw^=7k9l{MQlru8g%FFd2=r7Bhc^MhG@6B7Q&a1<)H6jMO*w{av9ynKCfo}<+r%)5% z{`YfOA5acD?k~z7ZqH0@F9-x@o02retvQg z?3qk~(_*%F3M76OmzQ6clQ=LaTTiM~i>QL;E)k+LWBxZaVV75?9u6xnkZA~caWLdJ zcBf(Dnv$i{fnlicD{wouwmfdTzNbTX!YCgiv5GFi@9p_H@ZkTF(nqrXS@H4Y!6702 zH3Sr#5ZiBOxf93H(f62RU*CSS;q;>W(Ay6LB|4eJpNDS_5p5o!nw2yysiveQtdwfp z_g~4C*MVHj3J2yUp*s|ysW)LpBLFMje+<_3>pWK2sM7{Gi{Pq#Bs)hzOzaEh#m7A_ zkqr^V#7rG3UH+2Q7%W+82|{WD7FC*k(=|kyu~&@V+j90Ew`3dF*0a^CF zQ1I*!Lg3((+gty!SWjaf*a3MimJthHK$vXk$Y*F{!wv#e%wial9al3rClY7PL;!{n z{?%q$g||3=FrzbS!QaijUBcwPW+kwD`6i5@&u`>(eCg| zH_>W$;-sXcQBM^}@E+aTjy&1u8KSLdj2YDdlI; z@H+`P1Qs0~0IjzyBqiS$a(mI`nfzzQ9%;-mW1-&aD*~_8_O~OX+;&zZJ1-Au0GNjI z6zPEG2MU{3%@>Qii|>$GWY$UUw)Q-8eE06&+&o3K$#Y4Zi#rfHJSzNp`DpeZMtSh9 zw@e8^G3D^S2%c8%%gRArKpi5c*NB2*tPjyk zj|0aDg=!Un|9I5q2@0Ny%SKe}GE`8JdqhLU)>B_4PTzlb$Tl`%s-QVV0Zx5m?AhSY z4ol!(z;*JYf-O2+^X(1`O?NZ2*zT!XA;2GXs;EHv^Q4VgUD%PS(Lo4V9z?pw1}pkPDOuR5Sp0XV;`eJ|0p1GmP*Dl*&0iJ*%X-s+@M&4YGX$JZI;DZ)FBA z8zc4rj!j++>crP~mE9yBtq$CeH)qjW;Qfm$!bKQJ>?dJhuTnBV0F}DPP1)Ae z)L(-ll5p*AMz=%Yl?>d6Ovnl zasWs3;l1PS8P<+BZ#^@I)0_U1E_bg>!+in9Tc;9fdXTVqDZKew*15+mC}>?sZvLI0 zKOu6PwQ1P2+l2t(0@srkN+DAH!3OU1Z3Vv3C!5($!L{Ss- zQhlL@3fx1$G?gcBycSs!gj zXnObsL~#}zw-$zoXlA{w5R`BJUP=QvfIhyL)38S%#-8F4X*eWT@jhmK%FfO|mgUQX zoqHgRzoqX6Dwz^Mvk6wJ8;1M{kXR-FKrAyelM`4u4`A);-mgWI$kDr`kEYA(%`+o` z9o?1$SDw3A6(8h1^a;;9eh3-A05O9O;L_oWtxKP~HI=w88VV#6V#bnN zP-}B9Rc@WR)$TaBDyBmsE*-&r{O>(pD6KNS0TKhRH9^RxT5Pw7 zTmWu9@C+WOfbD7Qos^Ve8QqO8A##zUj1^u8al~{F0+k^Tki^&7&7&%O75)r}QmFs@ z`GX{Ba64M}A0JmYW&fO=J)AB-s{$RUS5oCCsJ4heFm*DBaCRBFgD(z?pa~4rtaj;d z1N$ZeP8t@4jEOQ?n!bDu=zZ6Lc_A75b8`M4JS)V}9DC%Or#muALeefFr46es$KK5} zIF)GMdmafk`;xNqd{-QN5rX_2OwFo}k}pQtq@=)o&;ig{W8#&hssxmkMs9?~tV;7* z1c2{oii!m7#@13RRA&FPLN;%mkrEal<9@e3Qc$9c{XAGR#Nj!bjKfoa0{htagJ_&F z(``!~7SKHw@++lv)UrF_w>$l|?PQNsAvn)2@2CV|+WP-_gYOaLG4DsHzNjsR>SEV+ z*MM#k8*^mU^3d;%$y6yRseaq+2D181UKAV@G#xg~Y&U|{KoA{FRa_V+ar@?pzi+ua{U%w>8 zwT9wxG5y;_(*;Uv4-%6nd-$_bkWaYi)=b?&Kg!X%%w>O#vO2^3QYg1f};@i2UP0kK}DN>V76+QYd{nt z`iB?wgx`88GZy@9<4~y}1|i%)<-B?D9`czSfgu7qHe78!NTFD?16u9_i5Gl62S?}!}3 zW7cQqhpe|RXF5dtQ9JFBeTgmx@O4}ga?mI?=xygy7hkvN&k`kOa7&X@H}I&0=qSDz zc-&SvK(qj;Mcv-#_DTQEAyrFN2keSmZSzPOEUNiRUf%A+T9$5BnuUc0WVqKpRtPH= z|DF40C|j1ey=G&ql!=E&bv-9Crl9KQGa70IIp{wjLS$-kU|A&nd-GW;hv!;@w|hak zJrjT~q*9Out|v;CD2m^17;e}1ZFHVkgy^+^uE(C#;W`56bl4u!LnM(C zP+@%COyW?%Z3bCop{9QAzfkicaChy~(yQL0qOP!DtoG0+3!K2_k(d1AMZ3yDNu_}@ zoYW^v8trk)>)R|K(@lCm;-Nw^A6|s;J{CT$pdXn3M+&I9f_;N~K}<|+e|c(;dtX=` z`F$vBRjsj}(bR4|_fuOQ;0F~{dbVfOrG2Rp*6#TAX2!QmlbxqhV0U%yL?h}&BP%Nl z#3vn(qSx6x=kiHeYCoqe{BxAaUR{?<=BRuIcz+I~T(MKTHjAT~_;&KJ? z;tH@_QoQ)sqoy4Ae*uS7$c2>h4^y0Y!`7|Z^5)wE9u))f%)6r3r%W|AGr<6sdzqt> zxM@3G$Ay7Q;f+)#SnP_5Fx3YlteYXe?&xyU$@ zI5uluRp7pyuqkXp9cm%X2O<-RS_SLWDi-9M$<&$xl8n1QscrRvs~ohb-re053&f{I zO5jCC5(2QE5u}U&+HU_PdJ_twSMWYj!RBZONyGfs*4Dc1_hL}}3bUm8pH&;#sQ^rc zcu(y52KpUW?Qd66mEW*4yaNs`B)oq=9O!cx9Ig(njuumYd7aa-P^!}y0Dv9Cwc#94 z&>=(Gi>DeD=wKT=?DuS7f(krIJ@#`V5rLe0TxDW-^;ly4|E5ZP(R#8-dbK-rec~1tX=vR!$B%y zSo^X*(H--0ZiF2Kla2`_8g^MOX`LE+i>PYAg zV=e3v05Xo6+sfrGMuUQ}DkosPg50<|leNyLp0-ntSZVC)HV(7OLd7ol=sSPYd)!D@t0m})lX>v4~=um7>%smpa`_gNS{5;*S2mT#b zpXl|^7wz#05HE%yd8!WM1j2EJSp45kPy92-G6TA^=PX zDykfV<%Tk#L6b>XmUMi%KlPenYC3UtV1i5fXa8^u;g$dlmoe>r7()Q0DY1 z0y+z7nEH0J+p9x4sU4m-_ixhM+j)SMQY3%Mk3)SVd3(mj;LUK6t|HCy0s?Mu=2Z)M z<(wKNOIl$+x6Q|V77{oYXfID@&Q4p- zDSVRNI3EiUD*4Q^cQ#b8Z9lgl8Lp(H!hbKVeFYBETd=awjPkw+v zC$A{*_7$xY$chZq1KvqoN^bP8U%#qMH%aMv?A-xGS}%BP|ImQbE(Ip90i;=G(oY2h z6)<|^*xG@h8*V2#Ik_QF?iO&~qzJ&L4MvietPW)%YeD6XD++-4vbDFjcTof>{PtA4 zj&N(Lwi8G+1dfcTUhemDj5mFF29#?8ftgLBX#suv$S!Lv0)oyTpMh>PXi574RO0+W zKkFR;p+;ZY+f&A3EIEANHTF!K&(Bj?(`4^9;tqA!XZE_j;()YvKO?8{((AnUYr<+C z8&W_yf!xkBg4NX*&RclgnA2X#2DHaj(#73d70oR5La_CQJRQvNh16eOVhFuwloR!? zE0uJ+KM!SB$0{r3Mvu0JdAqs>c&$QkalXG@kZCY_vEA?p{$oWZuc7n14&7szMq?>b z77wK-7r>LHZ9+B8!TO0y^9?jr@SM1(%MFPK_O5okL10cqnk|cp98uMWX_( zSI^ypgLWXg4csfi99V8CC8bCb7vn*|_qjPI);EQxP&D|W8 zyAW4Ufwtcq=V}70n80uQb^a8oc)inna$n?Zk+|u6y)XbQwQXmcKMqpZ6B&yHa2_Ew zvtOJ{`ywqJ*t`J?zavPktkfG9FP@WBFX=yQv_{{p* z0_6J)ssgC3v0r!~Vpi`Tfb1s3Q$}I_Z!F2B<0^VIyPi~Uu6^~lEx02=skUP6h8wUX zybW~jzI{@xV+J$JIxw_&_y)pEN)}ax(KMza*4|Tf=ChM&s_U3p^{Mk3Rb`p7B8ktA zm#(kKjF8~bj~1bT^zFL;R@`8u*|+@sm$l7_K)of_YeCVgL*ky^UOWRPNYic`4k*|C z0JHy}Lql<`#>-IjK3?%6Y23(5Mm(j?CQhzL{w&;X_SI1a*}2nRju!o`$!ZYOpn)U| z&tKff>+0M9xG)N!8j{3g{&)AnV>Ki2sm;z_yfqpip#R#(*4;8Q+`u$0~NH&*T#k6ps*qjbrwJRv+aDi&V`>jc6@WTWU z6hI)%Bn`IwPoCTbvP5-2_02~K1@9lUr&~2T0_2igyH5C?C`~|~I2t?vyHoVm9 zk~DFme!R+2G*^TRJB(4>vWm1g}apH=oH?xt%e!njFIq(LjX;=#jF5u8%}SA)If% z;FT5y(#0D1avUx%N%p)htvyl!@(0okO4;6lMTh6e zeGS~!JW~1kGx6hF;|}D?kRppnnY$l9egGKcK9ZD}OTQJ;MN1>E)9gm7Rbz!5aS?Rh z#1bPHK;=xX&mRN*uU3;)cR?fy)>_Kmo@;n`80bx5hHU(v4Zn6SY$TBD)cW1%VjD80 zm?{Vq`6lTfKhB$+viov*>CqTLT6v&sbEp(cLc%EX*3%mn3x<>m9OEOC)x0CqA@9v* ziUw?LTDaq#{{^|XgG1jzS3FW z30Oz}7IfY(b0#(cn_1-Cxj<(=^Pxk%cVE#{SGVz-EE;Ok0C1m}ommaWwogRhr<9Zw zs^+ho%2O){(v!xExx0O5@5(1+<#g_>k)O_15^q#3ua7?H_yyG&UBC)=d!h3!>D`I_ zUKo>w$HnOVkn0zXb7Nu42w|umv6JU^aBwha8nW6KGMydzeQJNC*ToGN!N0eY!9<{p z!dEG;mqb)j0a?}|j|7t6Pn2v3o`daFkD|Y4pTudA(4NNc{?U1}p0>cN1}^AyU7=&J zCyE*I@=*+>*Q2Q#GnK4a-}Biai|R?Mw5*#gBZ6v8>(fCKPo$h?RTWSB7#Lje*JkHG z8_Z`8Rw_^}QP*wW`xiaFJ~(m9CMTz?=_JeF8~rLSJlI%{qb75E;J14UcJ3o9N}!Ph zsn1mByhSxyq7woDn!i)FEgeV_-quzzq@oiL)+>3WIb+9^g}53WpvW{O=LV|E$Yp^?dKYYGS16t0`fFgISjVJ3^B z4IZ7VP+aju=-=q%@j6eA`?Md(FFidsHGhFUw~2lI3GE7ROB+RuAzq?*h@OoJbC&&z z_R4n0!Ue(r0D|Ldp}n(Agmdvwl;+P)nEP-spU-QZuVW4Q{a2ut<56o>+@Egpy~1gA zrdXo*Q_}kA;qLCsR|0nO(YwEf9cxc^pTrCESCif&nHyXPop7^(*^b*Bxt)*s!AQ#O z&IAKRTt>x%HmMQ7x^Mg5OML!8z;Pa1u*yQS&w+HW-6QiMs4J?(mjEaANi%nvO~_ll;4i!R(F5E~PX zf8ai)BH9&))2C;y3nm(gIvdW52Zp~K4`wQyn4hwjp9*+Z=}T457gl$EPl9~ZE4;7} zho`a0Ic%K*k()G+7Hog1hXJKI*w)<@$X5uwsVivdLIMULAIq?v{q{)5IyxqXh?EqX zcLC;}Izbltbeoz}-#cUKjps4gB0I0{{HxkY&8G!kBsnd3fgyim2RZ#qO2BF5}3vrYY}qvg^k*?39){KDjLb$Xzr zCW__gL*Yf}3sGYl#x`->S;bd&^j|%$tx4j){QYI^zPD1-Gxqrf8%0d?TQ-}BaNKrJ zm!ozUvSt2JNu_kr{ZLdD6mKm3ov4^s{S7|h)|wQ@fQ!rGw~9*E(CFHS>Odwv#X!ZB z0zRXGu7Qr~s!gie&8n61G7_ta3rFLSoRtz~SLU9c-YvIKh(Pv|1!7i_*LsPHR0HzT z8^Md#kQRf|>^@@ViFm<}=mvL~=J4W~(*3=nkf8)R>cAxi(ullSzA?8q$INfdi@el3 zlR!|cg&WiZT}FsNWB)vDhvNxxeVE32@U_uMO{;4mFGgORL5|yT-;RxKs>o?!YetZP zxs~9<9VuAJUDqcshAJ&FT?OjIyCxU1D;6HuW$hA9|0#Z(3Rf(wnGtq6&eVxmh}}p8 z)ij0Vc=hQdM_g9-i2cKa@W-8I+E^lTM5?N5`>W`& zuyAp!&KiFw-dvX|BElfM!*U!3U6Bl{)Ajr1U!ObX&R4(+I{-I7m+S7F;H;XyKD|j_ zf6%#N(K-FBoP{intv*$zD#7$IKGiuenHYNuE3)SwCq6pRf( z?SKLd?IFQ}$_un2H3~8kk^~yH-wTDhdF!$$a-IAFEMcHrW%-Pz_|9`=5?<{lO(Nw{ zWA_{bJ7onGeXf~zOfx7zJnP0JNqhA{ji`KP7@=S~jhkBgU)s}cyf+qYCcgCR~ zA4y1j;#}yDD8zdVG6l7T&V2S?Z=5VHrYD|B6qkrws^8U|3FQ$}d$_!@5f)n1Q;m7& zG$TEqLbEtXkAJiE=hlUiireq9*-`m#4SfN9y&Oe+Sj`$RWE_TflXzM-!(isL`5~o< z2z_ogBCP1hQwt~FZ9=1&cO`@_N>AO&K}T&st|+{Etg{5Ivv@#2=ctn}xpr-B4J9@# zuH(jQ6D3eYG(7)N6(7gY=t6qh5_YTA8w9x9ygP;_uFcV-2xrg^qpiF8>DZ*k8*(C_(uqgO#XvjEpaW!K3j}iS!MS>rP z#>U8%m6hu}&bWD~!X{{B zp|bgi98b&5X_p@sX4U!=RE?uHPga#JdFtZgP#|W*bx{-7#HQ?GhVySv&&#YXCECl3 z$6m(l^mb15F8bB+4n-LR&?_ls4ZV|Vds#SEP9_$V83f>pXC7B+kLh0SY%dp6n95U{ z?%7w#fGV2J#`x%>m+i(lM&mJE$92}{I=2_JX&mQwW`t_QZFEtRes-iACdXCdC+6Vp zTKT6R!+c=Pvyt={B4&hS_l(BuYqkq>Er|Opcr=rLrM~!9y{ir*=TXc1KZLz?RFr$z z_d5bgmlBdDpp*h4T}r5+QbTu3gEW!~DhetoAt?egq;w+)2#9nejVK^plIOZn_Vc{! zoU>lm+JE5Qdz_j37gv10pQ{Wt)YuXpGEFeit=JUQ8#oRkvJHF~c7{xbZbeW53zoy{&W&V*N`%?VJ)P^jC| zQmTXVq@9?eNwT%g`Zb^64CfMa4jrw{P1|zYIzj7h{l3A6DyJeb-1pg7`nEULI4i|x zPj2Bzx|b52k|k~Y@NVZ#yy?EETGs3QsKS}F#AahXZOu}5L2u!vrPIYWJ>5Sh&j!`7 z5vwPDxq(%y^s6kNm=EA!eRUHK&D|ZI?`^U}G|O8HGyF7X(puD2c*ZueL$-QJGQ|S7 zb32CQpT`QcD8{CZXj#aIxI+3r0g(|5Wd}k!!YB$;SM_nf%!wazBpZ?+=C;oyp3HRs zwa^yS8^6{{Nz$&#$jKSQ0tOK-EMO&&xr=#jBtneVP<81f2R+jIjcC4bb_pUEVr|d!lw~@F!oINYiiSp!`k9a9k=^<1Ts#Q^n*-K53Dh z9c&d#bNA{a#WOGZn1+i*epFs3P2kcA>3-R~^BnHccdlzSoum67w*6TgwEf>bs8Ol8Z6Q#c%3M@`z|+U za883Lnky^E(s?OhDqJ{=I^F&1kB|t@9x>NmZ`O4~W1Y!coLgf)OKVG4MH|$WJ4w%I zywHM9HcN;a6E^Q2OCR8)rc3AU_5+fIv5P!ebI zCFR?+0+gO!Z>6N9vCi+wyTL7|DxH2i-C13BrM4&Cn1sDgjr*ZYmQ zuDME2n>eM|bM8-(*3D()=AK^l*`-D;{cf{ruT)`>eiEQj>D}X7ShmzgS+(Pyb=#MT zaqJbt@qp{uN?vXr8@1Gr4mu9(f9#sR=&9cGaFI<|U?LsG-bt-BAIF+pH3&PEqegJkqtz!tL@#4PA+!QVXe>nIxdSleM&D2U#&fS%3~p zOHYgAr7NI?n_S>OGLnUYGkSI?of~TFNP8rbTi24u8qaCG5y~yg2*BZ;>&T}H_f?Xe z5%(D;aMF(#_UZAJeycv%v&+=T!$IgA`};d~Zfali`%4@yxXkzaK}w$t*Ky#a+$-i_ z$Ry~Hawh1qpFp__c^I!ZVt1PTN6(s%*IF7nj*FD^nSoVS~B_Kj5p6CG2`YGv!3b(D5i%T zD;_m-*w=k}or^Q@O>ZypPmKFQW&Rbl?42QNeP_dz)J3_Re}(x!j9VU_NjK+c0>|3o zHN8U!I;n_`oB32wMDvUCi2cdKx5;i*cFu*6KUJV5Q?F5vpR@r?x;fcxy6&a)@yyJx zTTBCYsAVW`N97nbj$F6}1$1?Z>;|6%X1h%r7ifb~CVl$yHBYQMpWAYKIN&%wDBNSK zUp&UyO;q)WDY#OYwteedR0_JQA0uG?!O}w6(p=KyGN+jHhEXBgE*JWAlPRy0^>;?? zd~(goc&(IAdsz=t7~4F)?pZeh z`KZ%ee*=Y}-u)JczA>R?$L(ye)tTt`)KT4BQ2%OVX0%*(caXNAqNhTYDwwV8eK4IK z>mwCWRMrZ7MVndEPkHlotiA=!drcyZ--HL&gln7m+rmA3WDLi;f9}kB6+B=rN=;L=O$jx-5uicU(;ZX+1(|pGjim64blQeTPCRL)FEtZQZ+-KI#wjifk zo^fOE_BPV8?ZK*d6x)r`>kEj9G2Ivc^S-HwSPF! zVYJm%ze&hbgU|TvQhIxPBgO!vC!B8#HQMLgx@9m$!s8bUlC2E;L^n0x{f$@t>pl_ukHGLRHvf+8{T6bc)+H7U`wy_i+hShI6V4G(Lx ziM}pw)Nln39{%=Zaljc3l1dF2%yYn7<#5iL!GsgUg?QZf6ePC8OdCc)YZTILOw^i#uSl#>}hR$oaV;4D| zb>AnK?m7|9((wBZ)*hK2^cf~0cd}TQg{QDWizRaEGwJSYReH%MJr~PFzw2?KG47nU z>Mw0`^sHQJKTCju^d&Qp2SNfU`&;yU6*cGyNG!S;H;8A$G;bY>8icgxVK!`q#GidL8MJbnxUgw7`bJvkep~H$*o^(wlu(_x*Xt8CMm3gpTk8B8-9ElG zgTIBLKs)S-4~$i&9fyqx>abi&`+L*&P#4rBz;swr($N^LZfZ(f*ar_OsH``gp0o(x{(0@Y&pIb1n9@fWG{4v!PYi`)_t_UjF zCB4_|outP?LD0rOoc}o{rd`lsoy}yn9%; zaCRWR&&WPeq9XiGm~rxjI{Q)Gy?M(do$qaLp=-a)M@zydh0C;LNG4yMR(r#LX*b-# zZKiAS80w67es^w88On1Y=5C48v@l9{Jz?-CUc7_zk`HT0(dj8T!Lbj~HeMRF$SdV_ z-M&>CANbo7gy=HUE~?TNG8Srdf%NJ;IJxE9>2{kT5V{gVtDxU`>~sigyj_wBUt`@NC^)3%*LxG#CH}J5EOg6l&oiUq--RcT^bWn=2kFI@Vf`7_ zq~NR|eoyvJ+sFv*xG0NV`h&jy9n!h=W#x&<;~e%EswsRHg`bT-_DC^~czvm+t$6l( ze`~K*zx=P~NqZg!9T-xx9o~-*mp9TR&bE){+)Elei1(_sB<>XWaGw{eku8j`O2kiU zUzq$XPIS)h%UVQTrLC1+A=*l}*YKAh=(*AwUKRZDybGCRuoiBq`ERYNP*Z$z=Ff;= z8S=RaiGx4n0QCvftMo{M;^?!NFQpX~DQrh7o(~V}#EH0)?QAYKj`|z~zkh!b3NT|3 zJ~x1a#p+QtqD61&1EIwyaKoZd$QGsH)bNFgC&E*nPAudp33QW0au%6w_N;$Z66%i$ z8ELQRy#&Q~-%nw1n{L4ySOQmmzjK8GYaP;Rh0$i;7IqL{8oVd96lZ?CWG&nXC3*Cu zT4a^@55sJhNdLG*l+RR8Qh+)Hq4A7-IQK*7{MxGY{idy)R<9lG_20ru9Yg742=ZKs zYN719G;ESkIiA49Ps+uC-R8un@X!=;rz!SX`>p)s>4|~ec&03FseCcNpuwM?8+};C z{X=YGQXma1v71;eF`PE3@|I0h!}4R#B4=kg_hes~!>#unL0fHBivtz5par=Xf0vig zhE3G*O4YUle>0l>(dO#n+x1JNI>uD=OW-+#tHgeogk3sk2nf ze&N=V?&Ab|a>%o|aV1midlt%AT6{UzHs^$lu`mWj?*B-h&$I z4`}II>Y3<4IaS}RI{4`vKEj!QdUg=QUUW^J8Wr!h6>sSi$t89VH8%M>fLuQzu>JMf za%{7;l847c#`1;bUN#97s%m#xU3B&H%`Y66c!#-tEwPAV$^F#BBOJJMY(pov!?urbNlcGf1zCk}qh`9E%>&8qEtsQ2VX(he2c;r8W`ed{-^JNartn%}K705@3Y zGW5_Ji+x*|s^mr>Txj)R-Mrkg-qAu4_0e&4c|={JXImU=)nl18*cA!gLan{cU_B2t z?gj~wp^KbZ+21XZtcL36D^Z4H*sS%9xq$&A3y!F!MCfa9ah*EzXoqisX}Hk{tCN`A zx>ni(bw&!>Q|RAMy$OB}7IGBXjfMdEY8)BI2Z#3(L{*@6d(rqJ`57R14M!^6--^2P zLU*01EIuvXg9zCc(5+Lk87l!%%!{GF{zjTj`KC@Uk0$Dp3?+uLK9p}xC$TuU#+jZo z*=~Evb@hE=Hben9{xC1dH*4s3l23x%6u~30LxY)tfk9yot@(as7^aNU{>GSwhQ_nZ zOwP{@4WQSp%k_d#OQYD9siLCd(9%`>-6a+TZ05@9YJ!n)p#4&3pD|(2{Mvg>E$VDz z3?`{Z?D5WNoW^|Goc@Z1h`C=V--h7YR`0XFNe?nJ9&kvDSsjyxoyP35#r)`_0bMgX z@_y*t<042JeiWS2+z7-Zeez@F>jX)w+g$a1G(&ovuqUOjZRGgY)2BDjMNKX}=L3~c zW%o<#iVYwAWZ*^RXAb{Z$5h3?yqSDggfl@BgRvwKS@28RGToywnfWo^mQ~KP4_0NG zzW&+eBHLot-}c%9uNo)YT5sIZ`f$6Spr5RBZF}P<-j&p}0b@1i?6lN%)y@3L#Tnd= z`GJnSUcKJIxe+}`hxzHb#C1z*XV+i)WTSK?_hQeiyzxl_A;To zj;+n;Q=1hEm3Bp^$X(3p)hSeX_{ir{`M&lmobMa*S2KSNRi3h)YbTPs_pUJ_+^Irz zlcTQcOD@H*SfYC2Cl%;f?Qah9sd40S=B3sCG}9DfcHj27bVWVddUIv%^9z+T`GS2# z+~^>F!e8_2J)xr{hVrA#@0PY7o_gc95Zc;!`MPkY{U1w9*(@A6IJ3*~*J?xc1+<3k zu#Ro2I85djeUb}(d@SxYi z*$x+*n^jrc?0TA;yq?~FtNL}&gH&FKt*dv{ICFN=Y1b)Kw^+aMWM{lu)<$*)yQdE; zR?SSU&vyMstD^;f?@Z@FzMg!NjiT?Fc%M1A1fvWZM^gIpia!SAI}liAv^qJT*-pCi z7@xdh?9<~n(A9nN1P?+Y{L7awLyYeSf^ArcNDM>*hNXAH&Y{*!hK6OMIra5D#c#5& zR8W=Gbo_G$L6z|fP7`ipl%5A6Kn`VOWdqOI?{3fK@BmbyGt3-rEl2BLdwhY$nCqN0bv^XqGprxP0sF&UP zZkAVus&o5bsAYEjIvXPjg?D_*u%WsEnvLoV@9Sk{PIK$aw@T+^^V?x{vz3i^Bz{mo zlU$GR^j)9%0EctnMlrL=tS06>DLJXX0!3os@^<#!wzQ&X9 zz0Lqr^vlycI2UAJ+BfYjj#i)XwGEfMIi9MBGWBgpuT?s-1Jm4M7F-xEJq3Ze->X-z z42Pho(_Z3Wj7(L9y1d)Ob2h}TI|e27C=a2ji!qM;%^c##Ir1IPU#%IXtr{8OCvwDK zSqplq@+^H56)&}yxij92uY9%|m6McN!FY$*?_0#2>?#9;HW}lwx$%U;>ybTPIU65U zo7BSY2%1CFJ67DCQB;ij=SV4&hl|ZCu5z+5S-oS~Y3Aoy**j~qX5++t2JB-S*dm#D z^z+l-Q!Q5SbDHo*-<+Y>!X|r;E$nM-?C6)o#Q4422{D?fYQ94t-=n@UlFbF8Gqg6^ zm?Lo{*XGsu%9D=GBrj5qEzZ69EjLfYX8DIoYHA_V)poTjqXKjxuRu5qEfqAwJ~iAQ zm3<#HA1zp}@;z8=YbjRyp{#fJZsRvEIY!4$4GO%MbUajcUKJ=5Qqip7wMiau*m&~3 zchIjl61u;346faOTr3&bH858#hcK6!eHOUtp2~4{Jf%SQ6yE93YX3m{)6R=IF*S&- zlJ0;WX=P&j_|0_2({sLIGxU4y_zMz#_vZ08gq$`iynaY>p1qx3y)uH{E!?csN zKr3U=b7shTULY%;txv>rgA>^F+!AN*xO45Rs#jHn_1?+%xA{OALm*J z|CxLp86CWs`NRXQ>4_{1&AtR$r6)JF=|+31*~GO=aNLgvY(1P~7UdhK#8V~AnfJ0~ zLTeru{J>>mgXw&!X-WKm!y1qM=CiE^O4gy@q&k+JX3KZ(2!rT>1hr+TZhcaIlP#2)7ndnZG>YRXr4>p6&MgRU47Oqn*y|cY@zsCK3g~|3|xHt?-w_8RM zf0rlp#S}K-`KCWp1Wyd;E3R5&<&QJ@1`*ALTlP-&&l-687bK#-D|{Xd^eHsN3Kw#t zeG1WiG?l&!uJKhKg8X8{#OnKo{hRY)ug+c*q)7S{65TlQZ1njy`JU?Do^*DeQ9Tjj z68bq>Edg3Q9N-$oN%$~>v)LEQ(7@vf0@-f4M>N?aOP}*)UfT@lI2_6L{+^2!b+b0~_#q8#hPpFW zvpuX9nA;_m3eYp#iNAT>PULD0xwxJlKZdL*mHXKGk0*QWss*>J%{UO$-nQLpGZ`R>MVR_CWKa+9pU?iZCncmFt<@bCQCPmFu^}eN# z+~^0r-6muysi(F>$aFx1`0IU(VC35;Q{hg#`+fIB9Y+o7@<8!3%B^e6r;u^GKUMJy zB{gLr>@uBvKKpA0j+Jj-_4wEd6zF=-e@ZGOW5~)YHfmLd8|qa^1+nY;;&d*Nuq5qC z(;vgR?k`iYkAD*uC&XT(M==g2cQc1%D?5Gm;M`MCvthGp4YWBJ+o7Jjd&3fC(Z|b- zYr8nxUMM_^%LI~762|YJL9mVxED#U_Xwz=ozD)=iTl+S#&9P|ncpmqJA&N^J;uOZi z$`MB%fs+ukKyf}!(4lr|9crq6AoF*c?QY8DVq?Si_TDHwp0fR^t?SfuJAO!S+uPn$ zqc4=l^mf!D#=C#5M$QPbG8_-NJ4+!btfI?4A1B2fmGp3k2Ibq{uK1#{z?ugCN-7%^ z3?@M^jV)#gI)4xSb<7&F!i|`Z)VpFL4@krhyi-j+S=;zK=n{Wrj4h7USA25dIUP-} zj<2l}T%8r-wWq{8)^_g^-u=oI=5C<-=u{pX zH>Pk4CHY%mIk>{mMV2sQO7^X$|uTI=ZtS>Et8XAWTyqMDQ?bG~Dt z947QqCZHFa!bn&0nLp(7g-z0i6=-o~>I4WS99Zqo=hw{|tk-9V#s;`r` zBipp1e$u+ovGrSPW7UP{p{gA{sYb0G5rMwOH)aPvvJJZwFSVgL#mJlqb7VUBvm8x% zqg#u9ZTs-eOb=w2v4q~3CixE@Gjx%hQQvDFovB~V3`?H}b(Wo-$gX;?$k4~kj1^K{ zM6r#0|l z4xXmYlaRj(I*Dy~`yfdjoJP~#Gv5Z=s8N}t6)nHAUN(werP6jegNNeWbtk5i-#ki}H=v9lF56UpM6tiiyenz(@X$GDgsdEQKnw@E12bcj4SNKVi913cEJX z2Mt0-rti&#A2^TOXqNtQTf`<{Dzj5KNUcsP%lr8m{T{%aWE!ldjodBk__40r!M+K~ z45Rggq0D*LS3Z1QSuOHOI|7N9P9+{ci;F7!T{>~7D|E676}7kSPN#$^ka;j293=@@ zCQ1rLcYjV15e8toA6ux}u|2#B)b!{Cn+=7#TEDZ;4;I@7;x?W<_EYP@`h4d|?iFZM zvllLSX1wydFg85!CPd9SB&EJ)S;tiSniX|Ok3*V4Z^WtqR-6pl`CEwJpk_uQQ*zOk zDKUL3y!=uS6t}zTa9u|_>XNMTHXYm?zY>TFnMoJ&Ninq!wncDg(Ye|}8U2mnt(Gqh z(Zo$|dwX_YH++OWSxtE9T+Rskhj%-DmX^6Q=-)xfN6xSv10q&$4FIg?=y&CEAldD} zAY&s`V}vUrtnQ*o2dxXcumAXY=cX%ZMw56iM%r~j#Ig_h-imErxaIFHY}lEV@&D4s z;7CR@F7@XA1JQ7XiQ40#?Ck7@!Hqz(3F=ESQ1U%&;QjKvnLonr#$JRT3Pm2ks9&*@ zJcRBqv>NEP7g`=>39DJYBNjK>|JuoN2&yWQR5Dn$ck8PR(=fZQ_*u@nSE>Bfx!FIl z=rd1mHnthGeVV*g6&c7M>tfVh>59`HfXeu#iSpgAEPJbZLd24Z-<=T?OE}A!o&NUw zew*QY&Gb7x1X&=vKyg z$X||P;EuYC=Vm4S?H};uMchfj&a5O9PZ}x`nDw9n-GoK>sSZ z^67m^EcNpD&QTs~#WNtfG_ETW3p5TK9Z(P5aY<`TVvW<}!FE_IgS+tht3QVI!ovdZ z6Xfq{l^bi{a@7p0OwvuPKC?pnf<(XX`aIFnrc7e>75Eh@Y+VyP*(Ao94$i4bA_3j1 zt_vgV>fB+_Q~m(bH)%k1UbHHqW(2Dl!boZgVH!9ypHXIO8C6cC z0x*aQbp(cJfv+N7+>0J`{f(a<<3m{jKe%C+H0UfriQ95AU*yeuie|fvR!)54;qeP} zCjy?HxN!d3>1ZJVdk~$4c{N?m2yD4)L@(s|08dxs3o*ps92|`>1h%6JZ7TU_e+3<` z{$xn`9itPq>x0$7irXZzBVrWQHp3wde*!crR>+;^x=%0^{~F9VvAvUYzc?iHdv&=e z8UJ)<&;XDitc%_q_k{`=qrB)$r!C#Y=Ra z|C|dT%VJO)BPh8a+jsd@@MnFIjpX-htlm27_})5bA=i7UY8MO-cy41!zfU{Z%R#gS zrs=rG5&ImK@<=jb38Rz?4zt57>HPHX&8N?@Hm@zlh`)5-={m!yHyU9`!a>@kq{Czo zRxt8t^)9MryzXg7vJ7#>dapLCUkU)&!LLNdEcz3}(klWZYwD(X_lDe#uq?+z@k4Sd zR*0aV2S%4?gr`^~xuev5STLK-s+Nfpl^GgJztf%OYVzn@=EZPOf?oT)C)TnF7<1-u z8XJ@-1L-WLblwdW^6H2?2oplYIkdE%@eq)gxLA89Hh?~qq?nsqkCcvD>)-x$w}RJI z9P~b|HrjF5P2gW1eoE>HwGR&&`C4)063_R~^&UH56gU0&ovHS;8hvQW)f(FpAWl6; za0*;n1-TvyIy#I9(UmP5e) zFG^0uM5R3~tkeDG7Gew=`*i!xbt)R_3rL7kq~e_V^UKvI0eI&`1jrljvq@3O5cn^c zJgutbX&v)r`WA^w%hlRt$+Hvgv^3axgVY>o9{J)%-6PPc0-?9sWg0GV?^0k+1 zr@uKoj@V#m=E=xsT4i>+fY4?AmJ89h&GqE)J8jJN&L8?4^?bD4 zH?chv&Fk0d_F029$A8{=_mZ6=Bte`!@s0>Og$V5^`i;=Z?(vRGV@2PGbHcn|O+2Bv zuKm6D)b<_hC$F-kd|dPEYSh$}L^3j|&4t|xFB*SxG#yK{@>KJoRN8Yy#-;(-uK%Oy zi~2n%HV`9@L@+c9yBe;)T_$Pl44PL`wL_ zY;!k}hBFjj#t1rOI88>vx1DXOqI3K9O+zKZ>{PQDRl$Zv*>ZbRli`*5&%Z7X^Oan= zhgMC-k$if+cV^*POmr&r|~yZ~l8ZEdZh99z7B@`xf_W zBBJ*U+?U(`1CDw=+d9@)&i3~1qFqHOWaw4{xWAQLKNVM9|tmdnO(qz z#MsFahaCx&Z`bXp*N~Ws9B>VbMz{6AN_H}J8`1S^7(p5)bJ&H6L|zs1%L!8$W;I@5ZD54U9d1dtDh9~ z$uhj$r)8NjJ3<}Okc6Hnl?PES} zNIJuO(*aO$iQ04pYWws9Yk)QB%iguU-^tg6dOAKr4$`NB2~ z-fYu+jW5%VV0ViX^*q*>ZH_|4e3s>*aK(j#^7ET8a~O04%2WnOsZfH|9hz$oaU`@$ z7&GW1B0@l^`xc6sGk7O^PC0bE32f<8=2P%~%AA#Mi|3oo)PFDH*=#Y$kP@K@kgX?M zX)ao~{GqB=39vzjLgcT@DRUPIVV!OVT}th&#Fa};gSJO-Jw&2{x=LdBsJry}?=T`) z<0nguTFvUFp(t2xBvgFy8{tLFfl)F5AXGXPW(1kZV~_U=Z>vLU!)ta zbCE;^n-jWNwAOS)3$ndV_vW+B@Zc}xwWpHlpb$K<^(R?Vr=Y7(R(gscARu6Ce`gUv zC%pgg;RfIcm@W|D0G(jgzo$pygWm=7e~>8WdxVT(VPQO8+fIlx6f{2&uV$&?qL9oA zMuUwZgl8C4@D64a_{Sy#%qoaTPQVTT$G^3%*XN*%>H(ild~CnUqW3)yqR;IAF`gvZ zr4diXcK%~jLJs*2e;mZBW|ct~`1|sM=r7+|@e{E?=F%utkX~!F8V_=$nA_tc0T!`Q|9FcGtb z>d*Upv2H(`NlUAO85kq6;$FflD=Sy+MnrYM&j*lEt*qX8iQquL=t^K6uk$g=98p31 z;Dm^^t9s5LV2gl+H7Y&$7s}?i!}g$+ip;)<7xR1%v>T;prLjw=*%`lAr_N4eGt#&H zX6W)I%o3+hAa)yExl^r0X{KhAfXe#ANJ%+xxyUFaMGY%s8JI1@4+Jr#oAxjz;w6&^ z&6TRH^WZ`TjFN8Udyo4MRL9*^e3V*`z@7JCw}TZ3v4J8;G;AXb<-_aC(JwAbm`!k; ztWnaY6fAT93+WiF$WbiZQT8w+!{wswvMAVc@>^%O04(+)M$2vG@w0Sc2ib@n*f)ti zmJag;0OJ?EMi9&pP;U7^TH5&}jA&Z2-zR%`F%0y~M=%;;C{uxo12~sMqTt&IkBB&| zggB~xaQPVYVnMV{>5PSD?0T(aOCfTIYnc_b8Y}aCbl_Z$3sVESs}GWZl??>b=IF@5 z#emXvzF;0pspQ4)ICaZ!5YvQp5i&UM>i`z)XJY&Yh3$UmRVvYc1rYg*Ira!q8PFZUP}((yQCSEf3S|JW1-Hpy zfvYej=}ADqIE?HI_qaSIZF)6k{L(CNQw8iz;Nt_#D>(hx)!0U9xxgfdvhHb7!o|)K z)zn_{X2Er~XO2$LdEq2VciLFPK9M6zR$7$}bUF11PYlSlm#ByU?SQxi4*DH?yzi8z zMq(cA9v+fk;-;uWK&S>-(Nl=o195sE=DJ9-3?OO;2M1(J0JjU-1#g;x8(hDml-lW~ zjhV&e0@6R7MQOJrU-B{Ik-4bB{+_79|ueH{macR}f zJ$e3oqmbA3HAHDB$L#*ONB+T8JSpEO9lw)>MAJkm0`8WiW8>jngokfO_5Q7WvT;D) zN+pPUL33_t>@oQSwDStwhS?Gx`NM<|J%EXYg%khaV~nsBt)nzZM4us97-GV;ttN%` z84^-Mq>rH59vhrbrf0v*#W%Fk+R0CWN4LE`KkrWsq+MOWP*hY_>Q=ry(I~G(DM)`1 z7W;1jc>=AUGv%&}vaoQom{-B!Faav@E4Ym(-hUOe&n)=~h-wd#g5d-mZh+17 z`Xm4x&ZeAJI?Ui~N{KuMjm>?{5<$XE7*0^8gj`E)-{q`jvLRoY(oN^tc^MOvnL$xjF%DXUt}^4ob;#doY**(G zN`3}AQCsXD1zviOr8yUNq)S~8V%*aa<5#Zn%ClZX{E`k|f2eDo789=0DOJdb}VspYnq27+i9n0A`8| zuyva2B}F{$P&dhsl>C<5n^=+jJ>*6T)?U@aqs&nT-ZNez*p8l&5a+oofKSKCqHFjs zgDj1lu?NPwq>pFnABLyA(@R8+pyj~sXn}PL{vHE(yU6@LgxwLgSM;+!$JcD#@?3zHeU0I7 zdCq|y#T2O?uletBh1V>DKDi$8uS`u(Lk*$MMT&xgis}ilWX6)7GCMXXe<5scK?k-3 z)+hmV@$e&Dd&LF{*V<}@8 z+x-UU#ZgdG2du3*Avh_7;DuaekZ1iprJlwiOz?2D)1cwMO;M`l?k+Ut=)R60>Ibmo z8<4IZa_l36D&ZzV4*FB!Sbz#1Wk1ge z{SBP;;bmQZ4+a3xD3F(7zF=06p0K9?Wm|x}vgJmB>tUO%dehTWG4!tY%cvobm zJxf=DCKUe_v6RpdO|h`Bz``qVR{#v!ojZ4soBMWjb4e=9VmaiG!@6_43dXB+OYk5h zUxEuae^OLbv>iZ_fM`1ncB8=jD9xPc0##_?94=G9^wsc)gFoY7PvFP5cUJ3fVT$m1 zadABa=cW-e{xkJR0uvGu7iV3<`M2a~4Ag^cZAIMos;b{ADl~cO4kr!3D<-p>D35*u zjuN!t3=!@xuuWKZ0I6{Vn1?B|Zy*l3IBki{l2SbSmI-yM9?^BB-FIE*rDv3fn=0nt z-54EOQHxBnY5*!uK!W!+A)u9{U(j$NI5ptq=#5V8Crm~~l z4tE+Do^PjvnJQcF#$yxFx;zE_`Sgx5GoHhTFjWMh+uHK76qJ+A+Tw$e7B19`5%$op zt~7vMKQ12eAVzxe#l^))9C=wCOBxjrw0!wa>eGmn!&L<=NpLBu3Lc>((dXO)E}#I% z4cmtpa_t*1${A*TKwSdmOUs)IVS7yMwFjY1eOA4hQjU%v2OQ{w8?II(y<21ilw@kgqg72!1aErys4jBB z1_4|iBU0xvGCB{pF_e7UlWek}6v#_H9fj7;G=!!yg}vyOzjR0emfHcwa0`+| z(E@W5j`SGol)$X7C8?PIzS2sBO$IQ>#?Dmv)(xn1j6($2@Mm=f3Ev;Bb;yyGkAkIU zyj5Tnrt(|?=sc54`N9q4)QtzXm)C(*asD!F$l7q`hS<53Q!;mP;i|c`ylSwebmHJfdkKJ zj)#o#T+r*5X~i$oy*ZMHt3g!`A^n<}eP}iQ`0*p7@HZssx&H$p;%gt2M7xgcUy6qXJxB$se9P`>$$@R<<18V^s$#6Da&M)g6|NY;3BVG+nd z2-+9vHWXO(eLGkDD;v?ODoNB)(jRr)R6@Vnzt`3-wV)UL`0g)D_I4H!L{5vkT`hB6 z41q?&Scu4iB%Ic~yu7&(;UW}9WQA!JSopzJ7&f;7qC?qDw3=EV;InWL`-J?JO4Z3i z;7<=tZ2V_wYRXu)=~V--USXpfL0KF7oVj2mb(ElUu&PA=s@D0BwEFiqF0-@WaB#>o zd{3#Vrm8wM%c6)|`~Pidk~$GGWzWimE<7I||F4{VOABCMu&E#ngQ7AMST_v%(SL+@ zEo?1V*#2h;LPy(RIWGga&O&eQSSY-97YZona$z|XXacE;anyCn! zKGA1*+t<%U(Q{YDTdQ260thSoWb56jq&R*0=fBT30KpiPR@HVq9Da<7WMCMM0dMzM zekA6|H2%CbUQBys2RfW@uNvYA1F;8Kb22MC4{7xrzh{7NHWG9Gzxd!Oeu#RGUe>zv z7Yx${xe%w`&WqFN<;y=Igw&eD zKZI0ua5atl89#dTC~L4ElvB?3tUXsuuUcjhNoE`k{}5^fA7vhthLOH5T-GTyyEZ>?3fPrc z=P@^ICrp>lZtl=NDikX;y{~$VsL7FJ4oW(RAsv}W1cA;AfTm1BSdrbe$EjOJbzf|q z3U(|R01BljzSiEagoqUu7xsXKjZM2^1QNlkpgBUu(49StLVT|P(#XpGwfkoknOg<| zu`wwB0GDkNKFak+gGEW~IqCLm-Ot(yjzYf9omp&7dy$PcB4~%@|G1v+h0S0A&sT$s zJMK-vxN?%(-yAL-o^K}2Y90)(qb1l0S>ThC|DS*O(TI6w~(uOS>= zn96T({BnNlAs`~+9V+_2a430&_Tef1>`EiQBQMY6v1S2NDJJ1w{mj=U`WT|G$ro3u z(p!r)uY?nC*HE?-8T{*2gbqVQ#ADh_b&0BX^oJ~5ihO)3p*#8zXY`O)_PPRQ{3x%& zfP|&nvJh%u{;kGvSj_6I7lT2#zO1W+kyUo=o8S5{Z^5ZWiu2@gqy20+6(G#GoW zDZ!+da+lG}SUy!jOXmZ1c4QbG^7B3wXwG_iwySS}-MRJWkI(}d(x*|$D`L9QH+<8L z{-q8OXr@|%-saHE3HJca<`u}HYX-y709#xK@Hs02 zU>viuo0h5n~r$ zJCc5eYC)owXESu&;#wqj5#st>(-s;?x@L7y7@k@vx|0xWk?0V)msk>Fi8z-^Nltm zhlD!hXjk3ig}mQ-;Nx*Pjm6Zmb)Lg{hmwR1;{Xy)u=g4RWa~7)Ww;fs7P)&9$7pfn z!y_@oSa<;zK0h);0OF&-$;MX~BqYutQ!j^0`G6(k(>WrR1qJqK(jDjJ&ef`a`&w6a>=MV)7< z1`8fMj*DZA?Q^|2x!6Gy0i)g%+h?oPu&VcT*?Pz~s^Xu`f9ebK`U~IFnJH}|jV`ps zmo7wn$>E4E*+%?K8@12e>?HohC@$?li0%#dcmD{&YeTRJWhLZ%5rt-|+iXUfPUUBh zCbz)Jy8)>1qby71p4aASnEn~+n}-1(a{JUAoS+D#3_NW02(K1S;qMpi-^X0W!UW9+ z7XP*B9YdhmWB^pmOI%#_!3+XOLL(s|p?(Fc$7hqlj{0}uCbP%U2u4Ol83SZj5OnYm zGy-QHFq(*i)ZZTm3gH0fd-VSOUDyz8F#>iW5Oe|FyGF4xJxBFxTbrDPMS7Dn#L5aG z86FiZKyrtRe}8|!C;I`+<)$f~<2ZxOZluGM-$&PXUCOb@pr=1MAL^{USJQVw=%oJr zqRUi7aN}<6$o6W^DNCv{fK9FvEHChITKit|9rX4T&o)5d1x=6(+}va^wfr|J2*|>W zktG1Jdy3D&ezt~DWomc}X;jM;>Y7#`SK#3xG90RMvI++2orC`kwx7g43}rwfnLsbL zdvTHgl>{ck_{&sP4ls116qS{Y`pq8PONmc)4=xJnv$%%9kYcXeXf2M^3_-m04u!CuQ3jwCFzXwDF=I_U zJxtY}o3^k)Vi`wYrx$8H^UqGhBtyiUag^?Nel(ZOQ^N)*E_(;e&#>e$hADB=o;{M@ zVP1>cKY|knDgD$vB{znp%^KnBfus|#(Npz6Mh3!1*fA0v!$IQ%BR-WusT9B2s}|8~6rIxjCzquj;j z-GcdeZ9spervPdUuoZzez*0q_h8-H25Ray#qa!kZ4@nMTGspn6`SqJOevoLudpyo0 z;_?~xi-F_N|AlAK`h&+CqMM&X#a9rP2Go}jJ*?0EA2yJ5A~rU-e$(!a=9IH{rISV- zB}vL~ul%36Fk%L|iOAf@&n&#};b;8dmx7YwfO?MFi@&K0bT^HGCJCiIYi3}n;zwuN z)<;77paJ*@79VZk1nkc@e>FY_ZMO>m<3Q%SkueJTk}_N$gIe*@#s<=7SVlA;@aW*z z>*ag!;QoJU1Uw>DUOO8|Y>I>+yHLFb%t`T!-7Ha)-zi*%K;Yi zQdwF-kIP4QKjR_+9C_jKJX{VX*=R`;DwZ$6wP3W>QV5$3{TOh zq{5=(JVe~}CJaJ>&@1nja+k?VN$-?ga_soK&%gWzFIg3k*|(s%!2|0W@sNQW59NzQ zsxUkBFwckN*Z`gk()^G-*v0XE`7ap$avM@-0|yZ4B3(wX^-{obK85BdY2ahDYlbnf zf1O3>^r10n5r}mRVIiFsap6SBLa;mG=s`Z!7Nq)DA^byb&O52fNPj*Y%MP7Cb3qUH zqp^z*_MoYcINR(YK&va7nlb@X10lq$QlQF#IxKN;pp@oaDDIfTn3KPO>$l6NMjk(Y zybLpuC*hw%?PYB23wUVAzJes=n+A?hQ|m{^I@4$xIN6zr(kj0GJ^g4l;On>mwHDZ6 zfi@76iXuD4ZOLG2YRUlOt&-JmB1i;En_aNaW1E3V7{7FMt6jKv*Z;=No5+eVA7K&F zEWKk|!r2$3=-~mFNZRbf*MDYvJneo$1K3AN^=WBnCXk>PGJQmxldj4RP-bd?598uC z(ksUlsAiQU_WoBTjJ#-6HA{e3hY*Va70Wq!IETVvu4XfK%8ptb9y4I;H+!zTBeC{^%V_$m`ZT zeW}j^LaU>T=53ue2BsH>9YCt!05yRm*kcHI?R8w-OBk1Vk&Vq5=zj=1XrcPRJ3Krb zcHILkh9h%y%?1ssjic@Q z^0s*0GE|&C!?^`tmz3y6zwxsdFCHsrNP{L4>E=R{VsdMF(rvX>0InG!sc7V3pOU#G zb~Hm`joEw4@Dbz>7Nb>ST7{MYfF&Y_F!kE}C`|T*&;sem_i=q>R$+HSd{ZgC0ocDq zGuO$bNDLZ|d=j{*0uvGv>i3}L;0F5zNz`G4^?hg#d~WWEfMLP1Pn3_^8&ix>rs2n0 zb^{+4WjkClK1I#X2o4h7Qf~XBbe}G)qLzn)gP1)9BOnhlR9-&5!*`F!5{^pe=Pgkc zFo6XsrAN8c%iGL?0O4nZbM^*|@v7g`)YRO91GgQNuZPFahisClkib?bhsNqeS9cHJvFHI_!W~g1pg8vIXnj6Y!I6^`XmNl}#J4gL$2X35;aQ%AU4K;0Rq{$p7$-on z)15ttgwvmxtU*dBDryV0FmDZh-jqAI1xy0O#M0|KQcLmUIH!I| zfRM3J@xK?YOaq10&p`~PMU`8J6#1VhOBrB0!YYFjT40Wh(zQQ>C--!{(~cQ&VM7Q>DiUQl^GE$Oi67gVyD?b+Ri zF3jp3z@Xl&}x$UREP&)l=KL&LJxa%9B#o@3quZeU!kjjYXVx{@Q2MUfahO3SC zc=#Cmi%%$S5d1%seRnw5?Hl$Nk!(U%c4dd`kjRMay~-%c%3i7LO=OlND|?S3TL@)k z%TA<(P)NMj-Fkk<@xK4O$MGD$C%-3rzW3+8ulu^r>pahkB``)%rb0#c-}e7`DJZ6$$M=DHn@@0rWf2^7(Dz{7B^3}P!{ z0c-QugOO#l&-u;6bq4lfB_6cUS$S7~Yh*x`rA?-&g?uoZ>(}?N1jggCO1tGL|NasE zZ|3|m1UKzghra!;+ ze(#*0<#fYsr`Z}XDyn!_zd&}9zkZYZRl=@7HlzJ8U3OP2B6ok%Dc2lKQK zLv8Zx*V1$nGT1Np66wq;{r*@i*Zh`gqA8@Cw1U`v|JVybZQKmpI=Q~EoaW8`xNBQm zZLn6TB8Fu;2F(lPvIS$w@%8ofRYZd*aez{wVwz9!d##lW)9J)MtkITg{acOq6`s)5 z8XX%mhXfkXKbC>FZ^7ve&)V9W?pYbpeMg2Sj&{fQxhlGUUl-1Wy}ea;3j&5Czn=-0 zQe`F=S^;7H?Shy5-w`O%i!)nd$YJ#1pI@4nqc_W8)}*|kXGi?|Te7msJd^UXd}4iZ zp5!_nva{#>w=V~J5_|Rc&xVke{?DG`PHb9V=uSdFZ@yj;!P&EC*LHlO4i5QSoUUf4 zj)@Zu&Lk#$g&7^!NXq}aFURxr#KL!o#Fn)Z{%ow4InHSS!uspHTk0N(l}}g{}%PD zp(HV2{^T$K)@gZ=6E%73Ch+3W#Hci(llRhf=3{S?qXLOE;bglRN~J~jE`r)tp@i*X?BV%&n4F0gBH%K z;`smbQ<{IjgIb*vzgLseY`v+!wG{{M6Pt4#FJI>JvXSA(+`n(}@hKUhHZg+>H-mlQ zDBYhO!txpyU@$c zp>-HhW4X7z4m+Z(l%eL*&du~k0mT39?FVMsW<-h#3VweMD0^Dd8tYRhcXtK2?}nZd zL*#?zOXcR%ivrq3j+CXoytw~GPJSYsSQ$*&rC!L5MnLjd-lz!ORr0rOH#8t)uOZEc9z5deWcfByUlaI*jxqofnI(J^_$Or2#k z`zi^C_1}q)5tL&mV-`LfvqsC$W4-d>_vw&TjtrfQC1{*L|gJ%4=`qIf6o9>2WIXJ7a<-S83O~*LfjG3IyN>I8Wz?Jtbc!-p?q^F4bXJI zU%RxZC3eBg*4VxA`qAIptSb#D0%(e&84aG4_hu7R-~A<5&eeHu5daV2JXFcTjFr49 z7yRp*G@E=+<@~*Gl%!$V9Dr~Q&=W##V7Ny7TO4d`G!7mO+XEb^4cR?GH_ND2{H4FM z4z?l-;lC%3=X7vzfV?+I)o)b-GI_AQX+a4sFsBgBUYWO&M}CImL%_eYlUoKU2V6-A zvCRvL$~Q4FL4TZuE(F?C{6hqajNSb-e|ANTU{J^W4M3Ke$?*H&<6xQqyo`8Z*w(I> zt)eUaZ#Pg;lQPuT)kPZhOiWC)@UkR6g~84X40iVa0=3duM>$V4(>W4Mu69mzg0ZD| zxB$ZmA3Vt}`agl2i)Si(`42r0uqm48=xDkP@T@4kdvFIp_;+xDQYW${aUQ(MVIbwz zU*QR6z*aEW<^8?DlxW#wl?EGH1&IQH5@`V!w4PnU!r0fHuDcaJGJMu7WNTZ6=l?{l z(-0rVCnj((V$pAn+Fm&^FiLu2;QFchNYuJ?>Gs9WEWO!Rntv~ynDWm4esYEmziWn| zMV|$@uH{{g3ptEng#~IE8#7$p?eFh5%Hif>nw*?`Gi4FtGTP1yQj141Brf&Dk^Uk(lpr zD4fNSh>$RoPp|99fBW>5BqJ#_m6L^TPykH`;g6_rGN^IRu1c{YujUdxl_iNs5f=Y9 zoZ$Of;4{GdKc4y{TC(qR4@FBR{&yiQC;-hKD}H6QiS^Dt?;-KD^U*psZNq;<2gTzm zdYLW~aMW1v4gI(ChI{gtfiU!F6to2VxV6 zezD1Uuuy)_WqOl~pD5sz`eRB(m(U<3IKRRFJ(BZo45a2CTf~ zf4_+@;_lIhK(qzLb+z4?zhH2y}hb;(O1Nsx(-~! zsuXkKe5z=1AH4f^bgh5W;ciI6VYXO;y7&EZ-W;=}e4$`;!)eWx@d7|wWqOEKmDP9R!I5)u-~ zgv&Rrl#V=sGla!T|ND;Tya5R)OH%jj%nXK!DNS&4E3Jx?iy<$B)Pe^F_DZLLhC46U zQD{qwJFXJ|vIyYT!Ha8t`}S?t2-`x0(F}(f^P7Lwg%V%Hxt{U*TOJS@486}}9eUOI zA`cHE8J$U=47rWc5q9c`=6M)ZKYPBh;0R~PA#5xNC0H6Bjah4hu(z@-%Kpn?ED53QIrdkpZ3?fg znR{cG4Ilp57mp40%D%D4ff&D^9054h_7fVAyojAWYs=HYPXvQ&u6_JKM%Wd*I9*C+ zB;9=y{cv(ehQo~a6z8kSwW2xO!><|+kE^m|M;0j`>leqy(-Qbs4s~9Gw2zflc4rl_ z=o;^sesYFag!U-Amudc20}@HBDBTNxYmBg8H3l=V82{w!==E$4+U=z0I#wp&ko`|C z7G^V(+-l9-{XfgQtP+tN1NO31knAKkxvcVW@+-HA6dxr2%X?}dO|xR_EU>DIbukO1KK=Mt^4ZS`oL|{=jg@og z;1rwE!&M42v5fJrkl5077wmuEUbkE*M?)E*;?$_J(S=_C_3-(d#_SYF(7+R+>wOqC zH8oUsrHI(0gC6nuEK8t4xKU81ub+wRztkQ{-3~ zi$nB#abqaV26Hv9j}#yGeYgG+n#hshhd~XsauZY&P4ju@?v8+kHq^Ts^KM*vg}Hm* zAs>a!0-~muAh#-jp8+XB05KwBLJ1-NE@I%K=5b|w&a(v3vr7LU0NjXaED4tnk;4a|(hdc- za8p|wF3gB*g3OV3lG4H)!iyyRuX-`&EBjoFFi|o;J=5)d;Rb`MBk13Fmy)AG2}cSi zg@~c^jnR0|e~Da#?*8M)(a?^xfC+i{M~?WJ@F`kZS;ba!#wzh!_b`A%ZBu9G5!BoS zs)SIrC~*7HP*CV8L7ck){rOVhH#suoRNDcm5PH}sqXhgJ3w{H*3=%_oFSFW2PmdZ6 zt_IkQJLwh-uZ+s?859Y&gj`0W$Vj-(XNLk&r&3br-%U;mfEEfAeKvLNz*OYFv*(E( z0oD<+)K0S|C4XUp zj_-e0Ctb0y0{k8eq;%=SDckI?O?Pk-Z;KqUHx{W@oQ@Ml@i7C?4sVM_+?gHPLYS}%{xGi$H4}k4P}g4~eCwN? zBSQ%61{SECz`RNZu68G|&l$|uCB&>P4{;+tHbAi*F!qo+K^Z}{h^hgIicNhOq8fv; z4NXENhrAtfJ5=vNR)ctPzVBixfTHl+YvUwhNW%lpL)8u#E0ZHYQYOFxhPM+W%?D*9 zh`EkJ4xxGu7vc5 zp#HY3Qh(U?GF5cp1*`wbLke2)>dm2()PaWVLYB$D*Y(h(4SIZ59; zQ`=U5>Bpz15{qvAZhTkKU?W)X1s2w=Z=d_=^iWy>Rf=|s$eq1y z)+qVyrG7RrRDqYRuC9)>O(G&9NQ5Q5w&+7^hokhAPe zGmfojK(f+(k4%AQD!RW<@xX`q-p2Nn!1uC6adB+t=U!0Och^O$UzirCFtfGQUJ}vW zbZxCYZB6AZ6$Hf*WWS6U=Yr#Kz^)*>bj0XJl?g=jss*s9js(73focJ)O>A}Qbr5SHfyNTvDnZAYM$$qlopp{OB|@c0r5wZ1G@f-DDROZ&-# zZDn?ft5>i3UNc09sCCbyX4uD?C6}cD-)e4ehjH-aXhf}>9_JUR@`>q$WE>q!jfQf- ze;#~%0lmp%J$_6&?H)wPTvbBgEP-w;47lTiMaIX#^c11mh(L#;%1^kXu`uvI@^F8& zU22;5!)juA))E`{@47^5C;%>DP#K356ckjY$Xa1P$xu1PI`6I z4OhTmZQ-+L9{_P?JinxE%zI0>}}&~*^q_~cDn zh^>gt2ibw+k8k!6BT#o_d-`BM$8873RHEV2$aWpZ6`?=}TqKz4-JtQQTp4Oe6ks4@ zX#oUac4;XT9%T+hm=_YbbSoX^Apjdx+tGr)k0jVU-f|`RQ|HcoR>hP0`x6R1=QRT9 z!(l1$%Zf8xO}UaPX!y-Vre2=v zZ8V10Zn9Q)z7Zz$BG%#sJ3Ahrtr@;Qm8A=H2q{_zAdj8)CpSY+#L^2?8R-cgU|0i;h9e5rdd!5Q&5mQLhku$wlWG0}yR;bBb|ifStxAUvzxI6>()~1&R>z`60VC$R zM23O16OVdQ4KAqu${X4Y(mn>fuCHP0Gc>-Shj9-ef7&f0c)+f*evo->}z~mG#%O zKTSYNN{SlxU}D$;CqwAN;`{d;FG5cQQl2J1{e43+q!*)AkHzb6Gm*G_0D3MC-MywWex1vj^YPk!Jk|nYhulKGR-$gML#-Jnm44T&6MSm8a zM-D3Tb#TK#o3LwPy8c!O1bMI|$AQcs1Y#*>2H2Bw@E1|<5K)vN{r(_Lq!-K9eYbo_ zkK{K<(G^?GOK&-59k=##czRtzT2p)yHm7zwk?cfbRMBuXr1TuOieWuk5Vi@4UUO&X z#ntKX-Opgwd-`%6CWfA|!E+rOnoT$uR8m8~ONZ};%@7y#pJ|~TmMrGPijs<()5Q!% zz8z=yLfv3n8W_+ZiU)K~Eujlf7IwjAGWFZXw})Y-4K5H!0)WYm2%aFd^aDEaRKx88 zlP?g3mqyFj(ClVNYFfv}Xhj{~H36NeAJ%1Ix(Ieo8!Q287C|GG0@}=hQBlO;7yxv7 za{wk4Yo-%XItt}?Cejc9Itzy+0Lb~`bIz%M|RBOD0*6qNX)NC6M5`Qk{G*M=3Q3G()Mm_LqMz$ih6 zM_Zq3BZPbvgF&>L_;^Z?rYU*7#e20J|sG@le`_Xdg zp)I84U1>i?WB!JI?!#Tu*yKk9q@u?LOLQlf6QU6E43Cb50k8>jyI)jPAP@m9DNCtj z2T?_!gWQ#NMo@4KC}9{35U28}#YAj}i4bGs*cBf?F#3`?8J7+i2WngcjsNPVh(3%! zf<}!c94XK5UycKi9SB$&o!j!DCLdBy!-a1Fv|1p4$iPNLnUa`^dT|mTu1pwCXD!Q}PF}#_wn?^n;vs!?G zN!lkA9vUTgVt4kQx7UJcK^t(Q*We`p&m0E>G+Z=I39{#F9=f>CwjivV!H{>1GdB|o zvr9|Yk(S7Q2MP^VAt6e@v-`*76IjyI(=onFOG~UpS12!Dtg$>(^Wuf=>s#@ULYIvj za4?Y#sbNP}#xr-4FVl6CXFOT;hEc;y-hyg2HeAr0ngyg}iuYAKD4!0DJ@6^mB-L!bu zi~HZC)KjSG8(Om?FNi4{@JK)Oy>U#$xZEp?GQF`gOuT>VD#LEH^`ozAQuipCY9A|^ zXJuzMx3$SHZdpVDX1Uj+>5TW5Qz3kv+HLZ>_9cCuBnavdJu;max{PkpP?=2iiB$4l z8PB^ni;sWdUq6fGT2yG-bT-%jp!vmuL)rOdL{dRdJh_z0Ow5x+E2FCy+%&V^;7-=K z9EKaB>rkE8lI^N-g8S$zai#YI7WVXkvI!>f0Cn58%iuEL&I%>0@Y>;Yu(n6GrzlG~(}Pm`E4r=SmFflE)<` zKCOn2aQeAg^S(zFZ%IV)9apM^d&5iJmnTG4$_L^Ps$*RoH(Yb>d<)ZDl3@98VI=Mj zwjYM5wQrkIfuq|}0K!^;ISrkDDNTRb$)I#w5cav~G6)jk_{lu^?r}<2l)~Ol_uN|= zW<{=`(BMJlUk{P>jys7eI=EefM{GYMAOHu{Eo-00&J>x_-k>khoE9O9Vlkv!$YnT2 zpBNA&b$j+--%^2{jO&1R@l1^Jwg2huay;|PgwTm5I=y~~>wpA=4Hu;XaD|gZQBpu2p&!4-(09oLO zii*PQHa+_t0xB-E$yzs5<)iorZa%+_gP`exMl`^UOD##)rjpqb`w#n$5maQ?DRGQf z`rU0Wz9gtv84@A%CxTnQ8|xY?EnA0$qYO6)NV1%e2#pW+ViK$U|A zuqa`7fM^nph5HPS`j8!fWj-m)CH4c09Be2|=t>%$<`q!Agfg~=!nQMzxq>EX7G6F#xS3bU9#kF3^W&6Re za6=)VNk|AIrxbuWi6khQNvDejLEo}w*EN_e{|68~D5#bsoYsQ(fZgsVkx zxfoPBX`%uBvwiku_QD0B;>p^>HIJ&?=0K@j?}Riuf!~pTL6$jl+-dO{cj1D>ylG|3+UeeYWO|(#R~tZ2&*kk_430)Rbmj4SzEW}2`T1OWooQM*!M71b_;e5M8q4UZzW z9^?WTXcH^!NaP=cU0@Bpz$$rX&tTCRn$Gyd#Hg&xa(cT3cr)T$z`Z~7`Ew-XYnCwG zK}l_R`kI!Z8=#rc>NN0@Ji;=}Qr@|$rQ2>9quqj|n4iJ@0N(WzCpCibTPdOR9!!+l z5%?@wuK7o`sHPYlyp!^lJ_bv$=tktp2gZDjYzMG9-DsBM$zifzIg%_q4U&H$;JCdG z`q?t<6z6%U3slp9SqIrg2#DdHK*h24{kwio3L2XT1Ry{sGG&7ZY(8xPg%-+8o&DG4 zHFqg>3k@_sIWu&yKxBfnt{H)C$ByOLS%bv~DqLn}%zS{e?C9jwj1m>dc)_0>8+lEn zU2*PsO+rX02fCd=t^|#z#dkM8s3K)56j6#$k-Z+$kyT;|Q-+>pyIz0Lv+IYX(;j#5 zsfBfNTA4 z$cb>5A+oz~$2;&>y;q15Vg$B$Q7aTSWX) zoT6xB_1#_nFDtSUHM;QbLmBOvtyuJS?sLAU;jaF5zK5B`*H3Fryph{o8+!HV+tGoL z5(S$VUNMihNAg!}0B#1BQ5o2?R(*t$4Lhw@zd%Vxx57#l;ex1p02_pbIzj+I<>xP6 z%mbbGj4eI4Zhj!R&BKk(1||-en1>GuVwJKq4rU5Y_8{UrSC`A`v16>W?T^#7~|}Zt!P0#Hn4Z<`XSgr2!2&(KeaF4 zu9(u)pv7byb2}k2@}7vHFgRo4p+Xn^slaq75U4g>7Sp=Xzn2hD$yVyE3M1=ub!(mk7LGu7sK;&|R zPZ%6FR8b&>BtVM#&vDH0)E`Tp2dhLUxPXyT3R+9;a85uxxg=ZP12Atih{({dMJmL{ z=k7+sMaFXCg#4le2qtv#qR&{pZi1Vp6y$xID-$DuQmlcLK}s>x+4+$n6vxHT+5wWp zf5-Dm>IN;|_SJV9%gfa?aH&9K4mnER|J?l(<9lbiQz0<}AL6?_Uj}%)$^1x+>MuOc z8_6`RD1N*3^kG+b0f>?xWZ~|sc%?`COiros1xg#81L7)P?K+=+@~NA}?p`eKceRwo7s1c%RGPS8ewfc z>nsgqN8C6ohtDZE#HX9>!|1HMMiUJi&9|U}O>02Vl}}FBb$CPYOHq%Py^)aY4?WfJ zGU0`{A2;;ZCqJ0ln3DeR^fFuCPIuiq!g)flqTFeH*AycouP!=ReA!UEM6RHHd*dqQ zmvdgxx9(Z4sSi?ldg;h$Dbu|(P1QV=#TB)pgfGBNzz}8_s@rgtfM!=7r2_E^H(K=r zqEtlco)D<!@#_n;LKIl@`A=u19ES*t2}MO5rS0&vQ_{LA$cNbd$g-wc)<^)diC zN47lRGkJNn7SfQ?vNAug#((B8xyg4Hgb!6V=b)JlD!Zr!0=&5JVpH9KixAa7OaW_0 z0y{MXdQK>eLY$M~p|*tl6z}_4Ha5AU{G%9yM0x02Q5+rW5OQ&L_Ae?b>KVD;y#PL= z_=qeBcRU>UOOGXK!Dm8nlradhVHsKx?3&3b>>8?gs5pA{V$Si|Uqq>3wwd9i`b$9~ zSYf^U4lo2utByvERlOr{Hv$x_u~2hk!o4uZcT0Ljo>;tJB6g zn(Lbq(&QCMMXljAOxjBZ*#B3)wD~|+`IR9xBrs*3&yPXg#_sj~+Qc&(?BLE-ujr0> z^MM~f*E^3DU}2{1c2!rtXFLm`KN-0cQEIjIW{OW%Y3G6;n?Qu!1LYvL-woJ!WjhnW z(PKxbz3lmI+hT>ro?BBcE7|6{oiSptY?G*rNcs5drc<6 z-~ji(+>Fz4kG>2A0GK6-hjffzQTKBw! z&gjt7C^}&RI7Ne_2Y{MVIiC)X?p?G^RkQ3O(~`)Ip5<6L6M8EJQ}c+|DgA=%z}da> z@A#Oe;bFWNceW{RPMw^)H%~4`?nI7>d34|E&B;XLkF*a)Z*w%(I(ar0cX|DM@vLJ% z6WnYMM?V(fbmZZjeJp#C;2z5QhDq$_;QQMr4_`Mac`85tWOWbi<*8ed@8gAZNY=QM z7K_>SpHS;BT)si|yZp!(m*P8%Cyz`P?DB{M@q0N7n$%ytruV$WZ!m zB;fj3Rx@S6>jrNgB>lKtqBJR}QdS3VZ2&F|E-E^Ja1 zM9t#A*X!Q`_7Jj91d<^H7M%L-M(9X`OTu~JC?UZVu%OX72cw$sOV18v2rs&ob(=sa?0B9o!1(XxRyBfo=@REM%4*rwDM|Gs;AQ*;js1X4@#g~s*5qWH{jEoo&iD&uoFta5!wrbbxR1l2UC0v=ji6Vew zVq*Y-*5{dE3%rivY+~^^QJ^k%480ikhx&Dg=={nb_beUX>wKZ9XLKm`yes3eP)ESm zbDe&i=HP;(!nz=9d;Ocb{~NkP_m&$g-6B zLw#^O{Y^^X)yU*JvQ@3O32(#4FUx!66SY3ES6RVwy&pTY*UqPJP}@RdYpQ-iAermp zh~q^w{w8kz)s~@&Ha%UFm!#3%sj0Da^VHnTeoiQf!-ha0U&>{ ztxqGyS*ZXKAPW*`8}@;R8;Wkg^%VmewEOB*6mWtqpc_s-RxOP@O90MAjU}YGL@u^a zXGDXXqm?`5Jap55QMLw3H)c?U0naVh^o0n({n|Ver?Tp_15XllbLCvr{j&35zWoKK z?u~E^D+?#?*>6mmc!%Pq?vj3UXRW&S!_7A0ZU!F?hJ^9nAiY@POXZsTgN&9FSz_mY z+FJ$=5+DiL$Ok!*pTTajXe1Cx1sdC|b-uB>IJl%6Y(!Npdh*6(h z6BLo$_3BN7G$?+ox2Sr6D03`Yqu*EP6lppje2w^*ztUwp5un<}`N-)gUPYIJWRt?iB! z?Zjl!iF{(W;}IgfbY|1BZ}n~*)!26IP4;@qVduNJa2ZRSBCk2$LdBJ~aMjyrZ&I)5 zQ+%L98E+KY58_^?cg9LFfxn@;g(%geU^4eY(y;_Zv-!788_vN`Za9YJ+VMmJzKp!> za*dzyp;lMa4men}BB3bUTtu{jhoEIdeWL?j6w(o(_97z3KyQu<3NdJ*SOV3!v`Y9( zdx1DE2pQn-p>8vxlESIYfHWCVZRtcReR3H<_67r!$ftcah!J1l)m1?dHm3omTOo^dBvmVrORE zmAF5Z-D27uWu!fR^~i}LV>Qc(L)YD7k6?6hxNqE6g?H)gPSASuE`y?#o`{;lNq;Fii{8S27r(0KbabTmzt zY)wW4mWRX(`T;up_yv-hV~F zI;Ap*9&?H0e1wtVQ+p|WYsI+Y%oaS@j8}B`U&z@`rjfeI=C@3=s;=%7%*+Ko9O9k_ z+%%k$nW}o(?&J3fqK2n6C8sQ3ytpP#l5$rnYvW-|Md1|2nErP2Ak4LLE4H>v3C}N` zeJrqX$zI(Yr*50QwSJnTmZI4_WJmf=XG9x02OtV}EJfz6A3|ysXPXYS| zEDasqMGoR3uE|5!1k=%*XLGfECRF=Q~-|2{^zl(c+gYKkm6cPRTGiQ_QDUT2ZBI+@9%A2n8E+ z_m8trX9RTX+X>HnFn8kLd^|Wj1&CU*&mBl4!nGHRsl6{L+i!fnQ2mNJW%$+wDAnwY zrI;|U^=c?-Zy)a-SuIsgaSdWMZAqC=`q>da$QX=;w|Tv!s{E3`75&_7Zc2fqk`H(A z8SmwYZlh}~@lGRg-p2=I1cyd{%qKaP9{cE{YB~OS$t)!DwcSE59ZO2|#?zHOj{F|2 zx59FAF1di^rRm)+_F#DJZ=YDfAZ@I-$!J44@_H;nlz^(T(~vszDW3q%=&>r*;M&K$ zc1hJAuHQ@V(;UG$lqjKw>1#9=>$y^6nuXDw5dJdv-dE@7!PK@>o1IXx{#&l^LcBYv z0oWZ{`B}Hgw|;(6sY|@tdQLU^?Xwk)6|zY?)?H1R3$Me9Pt9x0T&ID2`w?V?LBJ#8 zyfM9@lNWrF)Oed!^`&B}C+CHOdQIbE^He&|<7v5P3j5vX>#cK!NB2IRcljKRCy%es zFQjrw?R>B1FUs`lr>LJP`u^r$xkXfugl*#9Xz!!Ws;QA+8w*{pg0d>-nX%e!KMQm2 z@2{z{Zk#;_#d?pjN&OBJ#_XG7=FF0r^_E6~e?**uerSFrZoujJsk!deyCPl_xFuLa zE|y)w73O0(!1A^kdHH})U%H*}flw8x>{ote3< zvwC)^*COfqh{2KNLAG<}+exZwSr&_FcFJ2Ru5O|?NF;-?3KoN^f?z9RAqBdf6UBLC#II(;Q315_}l!;mAVSJt-bF;EZ@To~W1J;9H;Y9FI%3bXO6E}BPew7K=Xs6svshbq{E7eYOx&}pU+33m1 zIul$j2vme{w=~8{lu{WveVr(Je(e*}sY|t1z=jrnI{EyEZ;JvImB}M!uD-3pl4pl& zs|nkNC+`|wP3h#k0TDM4$GBmbCTBCVhMTusRgNaOOj*BXScjik=%n2t%yjtU(Q5C8 z_T-zv`D|{pVrB>ZXYe|hgoO6l(mLdR8tsrip4wZDy{LQl1gG60$_~yJF2CX5>95Ah z!*xfOAQ_EnCg~xoAJT6cMTw&H4YdrHjlOPu(;(s27zRQ?_-9yV8b`+@g=1JHc49u^=MpTTEodvd$rq=CS9%SiEBQcO&%brX(8|I5CY1C_X3 zn7%smmc9E_{%}_y*J;ha#Y=SephnAZKJfEV*bW#oOH{>AnQiyk3?0&SA_`$Y_ zc`fEbbUS}&@W<6ZdS&RxWx`rJO1aVWvH&xunKZ)s$(>9I>c;iUI=hdV@6~yl+KvQ% zz^8{D(UbM%v#Qjzw^n;J)^NhImtnLC<7NLn=A#ZBtwP ze45tEnVO3CP6J!J`9mLuw;B5NYV-lge{qKH4OMkfa2exnft@@eOgz`{q zMJ;yCG6B!Wi5VUbk_Ra-w6#C6mr0nrLHVawNfvRGMU;|^&=*g*AfRLM!9x3s>iko1 zXN0>pQM`gLn^xRokgQ@Ffq&96`Q(2PHGNtfjg#+v#xHT-$L6JpuzwqvmH zwuji96I-pjg*pm(3tR9)f2YH5Ec_CS^UuhS;$L^)AXsX6e2Mn3_RwP7G@_ooxV&BW zE`*MU+_hXHKS1|h+bB!DP_em6bzath$0N)eOv%ia&YF?n7M7rly-Gm$?WMctvJ^=tV#KVGi)4gx8{93T_V?;c3Qm@q^tsI(a( zK-2`lu^m=UQWH+pou_x%CSs1m;{VzHq9H#ApqZ^;J+@tswmcX8O}&#;S1*CX@Stmv zWiG8v+=(6UTV^G=wF^Q+rz{V*4l6v-jK!0&<+oVcOQtZRf2(T2w zcLMh-Hlf-1#Dp%yX)k<^mpR>>S)4B_%s)%+)yax=zoAP?=2nbVrwA3%cY>q-T-f1| z&W1DCe9X7eI>fX~Ni+r&4Sq_2liPhBcU-Gl;~g%Wrn+$Bx&_tN=J!N?Bvzl;y82oy zxt$@$yUM?+`uw!;(cwD3i&#~~%VzRtD{)J9$FJIqRMgg13ZH!UR=9UCgm6qS^nPP67KEytd%IhS&nMJja`m&YYH?V# zxjQZqR}r{eCTwuf^gn-_%64}=Yqbt!u;xE&mp(ro=@ZVp*I)^gkhSY428BPQ*W^nO z-1=p3M$PvG`y9usdYe-n?nc+8UPJ7zLHj8k;a4RX3m z&vWpL_%6Q)foW*zS*#6pCO36()hl!2iO)m)y?+Pa4@=s}nm-GXQi2C!Y%}~LyzKLg6 z9a$Hpd|sIRLjdz;!}5}T!kM zi_L#eN$^7o{-bJr`Y(Hs633 zPgbxkFrPo2rTkei!dLw*b|46w)j|#G?+|RoR;YwNK+lJhhk!u^&51)?>F3T11P_Hc z$KQOqzb4g7gOgPFQhMSYW_jh-(OKOgb^C@BH;YpuOnNe=M-B0~2T1I8YWr5sP)bHE zO%IuETPw#{DbHlso5}%94f*Ass>8>aSMLhR_h`>fQ}EQNs+g#sSlH-zw0XOOd-;u< zO0eSaxPxZHqM=~rKu7$HTu*uEoud(=h7WSDuln|d-?0)8ZadELFjKY7OUjV@iC*dO z59MQWtEu-rr3@*ckkH}aymEn5vMz$wI-2FAu)H`gnaFna>McqUiu(U|8Byl?2`OS2 z#Ql*h5M(JF1Ge$;t6ss8CcDD!0H#*;~v=Lw3)HEbc=UO)awRe zn6mLL6)vD#{-RuF-09_3ZGLs7mvL@m3BT_3tJUF&>3-TP!PW95T|-5q<>iVH=!$%6 zv$#}OkE?Ls^kV~Q;iZ8_Z~M(Wy`z@Wk3#x7ZmjiO?KjsrCZ;|5hv3S2>T;c@ADA5; z1| zu+&nyy{@18Hc}Bnv-LfVuD-ASwNh(tZ4SUg7Pk zI`rNRYx^73l-0GCi7Bk~sf4Z-Vx#*dxg*+Mao{ou49KE{arP?;3c*z=W&%JPd$f^L zMdF2vVeqXvrk0sdaO!sT!?1==7q`A&RjB>4e(&N=`gWw=rh#^h@SW+-tAZ6i!4!Pxf$I+cz-WjJw+BEb;z$~M?v&}g z3nVYW-#=sG$nR_C3Yk+apPNDUJmc;)7aa5t4=2x z4kkjOhzk#<*uXA97&fxj2B*L%_m9rN*L)=5MgU}jGv2A?AAmoC=2Co>%zX&53&fe< zGC68Nlz45u@!%&ZlH33b2=TIQMe__s^WT-ej~M`Dokho;Afem*=3vzfiK%t zSL{Bk_VLvWiVlvP#_rdM;C6{_)gLX*ONI-q;^1z{QT1ppD*I(T3r{VTAs+Cly9IQQ}h2p`9Cx$=SOj>rzkjidRRihpS- zKM+)o)r@OQ?26u;LiD#ja3%uMlIm}EC(zNiKcw75gbKQwag+qzWu{tWlTLu1`NOL41KG+}f`cbiB1p=#^6F83*yFc3&EW``6jq&1e)>X`%_*c7F6| z4MkxfriFrNa@u|>+lS)Hldr_*@qFjS3dhO`p+xDQS%TB4Bo3GEqgz`HC1=je%+G`l zCtw#{J~05ZN71xb5JAZSgCXJV{QNwci@6C_(V*#gF)jfe7vPD?qIM>HOwi8I$T-N9 zx~~`^ohf9g!=pg^s^W(^l?%Oa6JvyA3;nyJVAv5**;;_bBLK{P{6pAiekS}Vcne*p z8KM6Y0EGR9gZ*2`PaTYN`(CVqJGCO*&A_O&c%9Bg4m`9AteR!T=UkUG(9{ngv;fIg zD!k5Ks{tGiKo>BdbmdMnGDiV&&svwJ?xWC-!_MdCVt8YXo(W{>+^^tynkN198DSGd z`E>QFhk6Ibu-%l)!@rUIBGO$tf43@@wZ`c?*5!^!w{YT_g05lPyq`gJIo(H(7eMlcLBHgIwxi^?WNT&VY8bzrtD_ z?c}A~b-*bC=_L_Zgnfx0;!fpbVw~EI5Z3I|75E_)+!$AM=dcc*q`J4+Sun z!Hc%B|A$)2f?s|`?!J3KlJ}`Kz(sUaychjmZ0tppvlUw_qn@o5ITW25V z{~Ct>fE=qm;4o>a z)f+lbPtWV7(r3@o!x$NxgKJh+bat5+SpW13OiWF&FvlfaWq^%|hFbxb8VRD2)h-h7 zO5fgn0@WDamG9sZg9P_W1G=*iN|BTrXsaPep$UYwRv=!j1ARJ|7hdnQ`O8pp7@1-K zhGW2@3h41TCOKiJYj zC}iqQJEsVf?ZCw=!zC4XcU(9f$U_&bgGYft4N|89li-f|bZ`k#SRrHychI@Jvv3SK z`N93k;{hb2P@Qk$ibfJ#3=2FFp~uwAJ~gI|bF1r75=C@ht;_3Pl0LNK z)RECUM84U)<62c>ELjHN)0U+Kc-+t(y&}hWf!beXk%_up8~0z z1Z-ae#)ahhJDnI10T2Q(GcYNEU8S$b2g+S45Ro;3*04}7RUDX?z^fn5{n-FDF(?%Z zUG{;khOB#l+k=H^kE9KNHXs`OiNsqNX0W;f^{5Q+=aAbX$cwQsNL7YpXh;kO{KBnj z)G$<1QIT_X6`~V$Ai|*jAiP)W!8;rvzy}vM&W;qI9yR-7;go`7H`vZ=ev&{|X3$?o zV*-E*re<4p3ZqkOLW`L820S8pD*9UAc)_a)S!`rxvgPZSL}#lc+3>uD`5Rz#@noA` zKhkbz=NZMzY-UHvwKYBzrSvq9EqxBZIRO%G?4p+){hfu~K;ZNmE^u~pJ(BA`fBLb_ z>mSl}gxL=F9 zWfC%2#soR%1kY^p9gR!$=Spb|N|iBK64^QXVrJFvfP2$BA3frSHz^>(I8GyUoea#1 zu}+^pjZ}ePAwKe=I3j}3&d!cGuCuSmShT(yEF=tyY}EOZT6@(iK4thfXJy#x51DB( zlQBak0yVQ7%yF3g@F5u14~));$s*cdx?@4}PcxK;9xFLhl5rb8IBG`cfz(o ztq2OFvVzr=z@%Sn4j5nO;0lcZ|eet z5Hw;u;75_Yq(TrlO+?)QIzDh-^!5w!_2H_@D=V}xx^yYBr%3>0m7oR6)s)*{sf7Bu z-u^TP2Zz#%wY5fBiNl}BK zkVd+{`JQv`y`O)a$45Pz&1Su8%{Av3zu`qsGF84Ad_cU>E0T;>n4>2s!VT7fK+%VI zZf}#YQOjhSh7$Sz6mMR1w21$l7}c+5_SPM|qh z(6~O)@Z1Z^X?aR}Q|}oc)BR5XYK$DO`p%x$05Y(ve&J#j+@agE4FRM?5+f!Qx`Q7p}*#{ZTG{zZj z|Hx4Tvym#2yio*XIM6+`8oGu6t>!yYID_RGo-Hiskszlc+%*scbHGGy{54iqyrtu@(lF`QBm)@5Bd$@eBjS%|l*NMYYpAZigbZ1Wqy{gnk9zL_dUb^&miH?ZqG( zqt>pl)#F%aO6MTm-oa>%z8jqOZ+mfZZ_Kyi{l$5y78OOQB@RQ#xZ{iEo;otR=a;4A z9&TUXquW>rpQ74-if;DMc9TIYY4O&Fqn$SXyj3!(&?>EqP7+oc6E?+Zf=3z$x=5wO ztI#m5J=|wS`Z7^(AOlCZ#z?W?-ZG_by9g0`3fOiL;H-$tT9#|oej4O9G0FVtpV?7PN7e*E&0NTytxnSx`B%>7yKEi71`lYqWcaJL{ zcZYy23H#EQep+5)iC=j>1+!1lh>Z9ck9P2nMNZIBz1M{gmyRx53BYjfQBboY1KY*K zy2UxhOa5q6+$Wu9blxmt^w}dLKPCl0%M4>CUXuTWYk3CtobYU6T@ca+(M?P>OJF>h zIgew#!D4-95P)F35JQ8Q=#stnKen900(b^H@NljBZOHWI7>q#F_;5GL(yihXupah< zO=kv*#5iyoJxLZP($UcgC{PA>2}H6y1FJqug(+-FV6+3jc&p%)R{*%Ir{OV{Rl`Dv ztnfQ19MBLU6B1Vd1xPa#Mb=Mi!pQ&AF#(qI5D*#&Arn9flES`lnq;YFlYtjW`=2#f zsPgA2e+_2j%+IB5rS=K=B`@s89G8hl_sMjjmvCiJ*BV=eHV z(2w_&d1T)()4E^l_FC^^;{NwHnV|eBOSq5Ax&a?jTwAP_P!HyN938{FO~b!E5*G8L=DKmsav%k zue^j<%Fb!8True{57T?{8Jt+llXW#c?N+yehlEi2;Laio%XE;2oJg8%xEbh*k$f_Z2!nK@aA82RDsUsv+UGKHDX8^g!F^ zV5Y-Be&*wi2QS4Ka4u>B!u5_DgAD|1_2}R-4+mlMPU)3}Qb{unud#ZXV~v75Pnp|M zB^;?n##+krqDvY z@i(J=E%b5}vnxmL4dp*sA32)YAJE4Ji_GLPGYqkfPA3z2_M|yV^PENxofogzJP$6e z=^uUdwBPN6uY`XWnpo1jp^lzsdg0SVkA2LEoPk1hm5#-KC*>NG196+~+W6SC<3Zlu zA=(^HrRv>UWYV3{!z@o-$R9|)xU*RFcWOedT174E-7RUu9KtR=OraZ3WHVIJaXyD@ zjiKHr^DGvak!-s=Vd_6mXU#kk6`5=AE?-s~5mDsUN>dmVAzrYl`9gj}`a8#en7ecn z*!AMd5|oUPRC7dGiV)i##bA(Z&c3@iS3Wu*Juk{YtPKg`CZnt)4YAci^`&war8vMG*@ye!n(6g>KQ1XWI3#=<4 z#`LTg6KC0rhs%Tx5-LN24)+Cawn}g(e^w;l@yUq&@Xxo*yRs?4FRp0tJpEb!l&_Iz z=H!q1hC%apWf9WuW40Eu4W^;)SrKFR6odZ<=lwAam)q!umSrFL%n$z%SB1%7;g-rd zA>xv25$s^|7STiM?1G!Z=EvITsKcQ#Y|o_e+f@@m&eJn`Fqh#RU7vjYEaV_{Vk0lB z`j1Wtt+e-}Z&WWE&RkWFHtQpc`4>Z49!@94ct*}BeItKvhVXCqy`GkK5 z%5ly$E!dwHC+&VL-EZ<1SX#PoVfn6Xa@AMNGc>rSB4uc1iuZi?(yqvdy-!XT6tG&DS;mv8kjV_16v~sa%E!J!L;?kKZmd|Y-yr6$G zboI}3y-v*_JdrVMRo-@iJgp8bvN4mqgKBB5Iw_v$= z%T$}~`GfPk#QH5JlCao?gxZ+Yt~6^}{PJdkhZNOo6Zwf|zeqc^XF~W}Nma{ECOH=x zrrr$I`PR|DDN1v(t;j)+iK{{sUvk@K#VKHWljSz+%oohN#YX>&uQi`X8DUGLYd*6y+mAoPR8E4;(zJ{C#`vBJ zRF0i!aeKA(zk2DGDNL(UZ@DGH<@a@+!VhQKJNpAar&L~iyb6o#n{_gG%SP+4%7duu zots6UIb2wEsoqIN5vbZHmd<&$3n{LZQC`-;6v(* z)8}#HGs9J-oq5o-U%pFpi&xF8W@7>yh3j^Aw_qBF5p2Soa9*@fUen_G!4SFe%sW$` zbgO@p|H_v~MM6@H;ojRsesz+e93Rhf&b>)CW_=%!T&Y?qHcSwDGNhk#`${OqbAxd2 z3kDCT0{AG0cl6`BxB_O9lIE1`>k}8kJPL454~Zt9FW}(hoTScED8~o`x}W4cECE>Hi)ls`dFd*zUbyj(S_d@uan7dXrKS4HeTjkNYqd z4V^wPc^7qg8873ZTA{BVh7nn#%*2C&$mT8Pd<<0i_!*@ZU-mEmFyU`h(*AZ$R;eF9 z6f`-Ci>=sQe^s)VG_>JAu&Zx{Rbkb7dqck>^V6rvM)@iJe7mIbx$Atz3UF}kc5?qd z>Fl=pXq+&*KQQ>y`)v1=HT}_F&$3&;F-WkM_z&^};b{jk^U5Ak_hH08t2oy_{eL;v z>5o|-$EZL1NcQbdjyi@Pdx61Bq~OSMs%cUp+91Ea&z)E7LSFlFfX@@C<;Q>JANULuK!`ryE? zj}yJG2JHK*4ea+R#TX@9^7HTi7`DpGVp3}It4{fT{`+~}cE^IF_Q76VKuu|#%@!Xs zQG0Wn(n(aP71xI%KsQ0XKbCC2e!xMBq z;rQ{#fp5Y&pWYA_+F23qto6vhOVNT^Ke{N4Zxc4Eq4(@}>t4L!L|$%CHOpO@eK}II zd&&2U7Ka=0Ut1{~{LEgbqH)Yj=;DcUS`^aHt<4N4(mp>nq zy+{GSw12^4w&tKLcDHwlFJIf{%l&obuT-F?5*iy_sZAoJ!?8JuH~wCau0%kCHpk%B zEzlL)^&kSzST5Hm81x2aYs`x2_wk5`uu%6OSe?Ip&I<8#DAYg7^6Ym$lAS$PC{*Ow zRKSw8f%W>eFYoGCUqyXOUHv5lqH=^7UsAG>V6_0E14LK;<3}53QR4sT4sdj8SWT1e zZW7=pE0P#rnuC@U2oP{f!*ZG5{X$8u)}|s_9TYORT;dx%>{OnNDcy9>ndf+ulTbfE zGYI>^T{;Q6`X`-1Ib_G{4GC@+%RT`%XL(DKgn4QWN;ZEL-_7-QV&a~OO_w@c89r1s z$rpW{_51$r=F)Tfgi+fS_Ka2hvfcI7dK_1t`GP4T+k@nL7Z}N{?5ba4>bC}{RA@qG zBOe@u5r4oh|Fh*=;ee6)S{z%NT@Y`pBinU(`wk8<0rovjcGRr&2gXTqEL36g_~oVJ zA*l-Q1{BIegQzJg-Z_RmVW7e*^PtvqbgWt5Fs(c+`Q7bt>vel}?p+QJ?Cz`R*_{_> z=X7~Cm;btL6_&JG#n^7W)Ya`7Zl^q9ILmY$gG-qkk4w&Aj=V}&O(?gYCYQ?jSL)(F z9(sBwE7ryVcV4B3a{l_VK8__)GrQ`d$FutTl~l4TQzQ;+uCTiKv*i=oZdZl4m#sX8 zoNn!deVz1pSW=a#*bK4XQ6lWU0#j#{j;URmPA52?gx#ZbS>@f|m2ce-emnP$(yJiO z_Yl7o zEo`P$D=|>Oo^*6t!GEjFE@s6eG9ddT=S|H|;$PbIZc5#*r|zi7M??;Vt*FqF??0!} z)Wqsa#ebTHd&UD@>7Ktat6Fb;XhxU;&aYoIC{&l{YSZu-g-(s!pZjQ+cyEFZM%q8^ z8Ekuwh%|`;8e?EhVE-9s*xc1~m9Tl*)Fa0szMhSV0q4O_O=3LR@lp<|Z@;vN>3lCn z>}xLMWS8EG&1^k$MfuH-`wmyzW3V{8@K@%GSHHXjKM0%6(8m@LO*Fky^dyt}5us{T zRaRNOZE-vqSx!OnEvmj~9#ZV^NoEpL9W`6?+@&8b<+v{J5p7hAAcN^<+7Te^AucvN+8 z>ngT`nDQ$cMcfdlSEEbWiWDJkp7iq`5z(KA_+vhN!kAS%kN8{UdOzoW7{|xQOCs1) zT^`?l91C>1nPJr4B2^L{8=kdwkeYE#AqAt|hfj6dSS3|zrLK*yyZ1Hig~_;kcMdA1 z%>uZvKJ*RB%W4rIc2OgGz8Hh8e>v>XrwV*83<8>#`g&>JxlQKKU%wmvkWH_lM_}k@ zWab2!;2&3`RP*+GTkn|%6t$A%V6oG?wCsJ|8GMN9@l$V04$8AkeVW3J>a$h4VPa}g z8(isUuJHZyySLgk*EqvrK6L7_j-T=S2xrguT}oD;!nSDe5<>`e zS$%oFV7Tz+D+j#&BrhCszY5~JZf3>i$FeV3Ha%za0ya(rGXldS1O&E7N0?<7&Zy_t zUe^y;&DP`&x`+NyFu^gS(b*HIs+{$8rZ$>6O4axk(_dJW#q_q``RgTtMLP}eu}z2X zKZ?t4(pj{k~20;KAQ{Fo^Ij99q#Nyc?=x;;Q%5nJ?^7O0aG; zG}V0Ywf;r6+_T3Ijc;9f{QS$X=Jd{cQ$IBCy@NQzUxk-Yv8A{n1p#TJnjeT@k*1ZKwLk0HWStXKi6%H%bm1k!~2sarr#zy ziHC3^RdS2Lbpq-JEBmvTn;6D5kGMM!=@XO65##6ojR8u3ykbx%;30A}EqCGWw%Qor z+ZVqq;z%}WlFn>C$M$nqHDS4Cf@OZ2(b(BFKbAFyQ6*JRzmbPf zBw`ikg1I5i>9RV5_sI(5mD$!O-aHrM5|%R|k#)?I&&^iWWqYnwaD=a2Q&?K@SiC_~ z`T6R*j~N02nOOz|^R)H&!Wh4+-8(D&OI|lplZ$qJ(W-3BNQ-H^12}|ED@Al{TlECC z9{-T@$iEozJM{t{&9W+{a@U7_Lt>^J~`nbkeg>R^vy* zRD;d7jVG?n+glx*oxNjCXm+QR&}#X`Wr29Nn;!0nKR&$6!{Vj!G2wPqjF_(a?)*QF+4`+ejA&A zIBDLjLo*~Z2|(}J?R|8*p!J>mZN-}XIbuHb-g=yYxrXd~+0j2F6oU7?Ma>^q5ahl3 zGWKapBJ;8Rk5f7U6}F(EI>E|HW;DuSCsW2|_qHv4tMFw1EWEGugspDp3*mUyzqY%t-=0L$!q<+gYlh0{Cm{nf{}9 z7#`ewV|PHg8C+7DEX-g*)G%A(gSEVA`lpGG}!Z@njY@6=(gkBGT&VC5xuyHwnHh zArwjVmVJ{~q!m$i(JFT`wcKjXE%@$4k_UL7QRy0qP?6j$QW8Al2<{Dh z_nMCf8^$uPp_7bAMZHk2)WlkRz6R-gbyI^-?RRj zLuSI2rMH#WHFYm;pq*PBv)?{s3GN%UeY_)VsBD7IK`wZl`Kx=j>{ys9f@WU(AZDR( z*{k~DiDg@L&QQEmoI%vXO=t1k3-uD;ctV?i955GaRoQZOiv5wl-Io3LT3s?6 z``eBo`ndKE>IoZM#kZ>zCa-mHHL#D0#+_ClAHw^Rvd-nLGi`6djtpY0EiSDc`bXWM zzUh4YY@$V`p2g`*e0BRe&V)~GppABCZbRDTyM&24(#hMS;)*)8v6{QCA)c|cVG)M0 zGeR693FwpIRAql+>jzOOsR_A@cP>|0*&B3aVkDnok5n=+C7-dI%=uuBdEST>Qtk2N zyehK!&%3amxwHqT(d^{sG0N8M64!uhzQJ*Gc1zKrY(TNGclT$Op2$Y}bzijM*Q@0` z&}#{SNJYPp_MxI8PO^dc`Cty%pdo@f-&r5%Jm7)P6X@tsF5z1`J=#HB@d#5H1~f=& zFW@Eafh^?3(O*+HPfzv_-n+|}M$2soAYu^VSAp;fF-im23V|LXzzI7~0KRtT>xm)} zct{u*G|!P72_%51(({08f9)58@f%YuByI~dTZrBzLoqo55__hBLDUFe7XB2?pYg(K z^k`?P2T)E6J|?E7&tNoxWIX}%I=`F&di5A6M1;3FNq+eBap=`-buTL z3&ijiWA5j*|H2@?iqa*@U$roc#EjumKkGBmbP}h-IB~&h9-`MTdS!uJU0}>Gof;Wl zjGy(x3hz*eaPqZ2RjkX%H8rxykMV*xJ1(LB$r)50&nkKL`RfxkrE!ij)q~=HeiD2S z>O@798z{EUbsT@o*|3QS&4#xLrTBMp{fuslipr~LlX}q{*JVNfMrNsxr#v@zFh@fM zcNgy3GSw`l6ItrWmv1-{XiEIu8v0N{`H5ips)F>=Cj-L_-lg%rN%~h>TFu3(-BZGu+iu!YrC5*|;7pgEQI0Z;k6k`eY}_#v5(G#aed>ch)IM>@`xS86ViJ;GbX&LD4RPw(=v zPcj;^`px(lCUC-<{_jd}es%Clt5LEMw6a&cBz!e1ikj9bPVX)_1{d&q7vyTnY9!Q0 z{1y{1D0O(6JF9m3sc+7L7;eMUk7jbtck7agoXa+no353b8W*ySnNX|UGnZ+&W-!;X z!5)=XpL*?3r=OglILWeZ5Vd@FbGmYTu3*Xg$I#x=RUD_2ZL;bMUVtygw%I%K`C>a zMIk;t@f;+qS?eK^jFJc@6VfX<^r{g)3BOl2E-&+Tk=0MlXoghzg1b^r6QIj~B#9|f zv$B3oDcSEaF=y$x&>exmu>mX6!@Kmee7!S;hTQRb3Yf{y7^(Xot^NhpO-t_-n&fJ@ zrJrOh;meUi=QQIq^e(B%g$;>8@p#)K3U7p`x>}|1zG=1yaEKPDXwvDAVQhnk5xrrHWG3M8C1Bd>-1StNY8>s`2= zJ?&m4Is8?GO1#VPFmq2xsLF8Wqlr_^d;b=@FCunP2V|B0m*tu2_8zagoT%nzzo5Nh z?6Fs%PgBP}q%=06mctz4{jar+0IkpPRf$^~?TDjqkC-l-wdM^^4=`QS=Mf~gb z(07v0-+UlGT0Jo<)VvatQ#)81EQ#`HOYAOb)b#yKe%83OU5c%E9`WR7?62L3Rsk>WFr*ut>MwIq?O`h}@u7rsMZ+$Q1z4ny zh%wJuM<)5GZTD?IpkAdW^dbu49dcxhQ&(%9oDYteIL$G~^hN(HDcs;`$UiZ3-zjiZ zC>_n~5(T|a-p;{Cbeb2GYcIvjKRcXiP4zAFrd2eNmTv2OT;xw;X-(a+5UCi0XV0ga zqxl>0#>zVXHC?^Y)Vz&@7(SDL4ITzm%?fwnL$RF#8QHqB7+He|z}hZZpFNS<=M zeuU*}&T7KSHer0C?q?`#2LeNNyR4HJbu2==YoA+QXvVF*q*t7x8rR7C`?7)F*C+T5 z-V5p>GSvFgL<>CuNgcZ3t@(cD4Z6yW_HL<#!**}>M&A?y{Wd2OJkozeS&U}|dXIQ% z&7{nJeYcd7cEm9akBHzZw_&+)0}TagSpLNuu6np5b(%c%Pqu2ytA|JLKjQWty37GLGKS#2l9`nm*?kI=N*CS83GWo89>Cl zou3K=tsFv%f?&Fze*HO6pdu$TB5XEAU14B|1Slc`7y>yc zBrH#Z?ZKEC&J*yxSG3oL|J zKth%Q*IDP_0lKoN_wNH?K)-y~F6X4wun9f<2nuIukUk?RF-4kRceIh31Nm%_Z~_Q; zwA?{C9L+p{a2I7Wvt}{I`FFvOLipFhQwJjNGnor~| zD8_efrP%r)kQ~6BAq*NrzP)cOswbNaHU+t(EdQJ^$D+f-*TbD7Rf`5Z`{G`d<9vwq z4M<#CB5rjLU|bp0+^nZaT$94E%{rDqm{* z;a!P#!7~(GBjA4>qgRNk`;)%kWyDOIZIpXE5Tg{6)P2X}?k|7qMVFuc--6KGF?Cp8 z4yk$y9!hTTa;eHtqLRBLRTHgm|9A3D)hjUDR<5O?q)sFn86O^Z&zveoZto=%e z>KH{6#?W24`tlGvrvF<5jH&y_^Sg-&E?#2@ojoJ;-8S6v5KzxughB+ZdFjvH`j=Ph zVr&-Xzjiu1s-I;G5L*xwk)E1dZoxBgh`daruVRJUcJ5g7b4aMK5N%;fDq-e2qt_Vlb?H|musOgf0idsCwM#9lwxx$_%~n{H=pGnuP3)v-E2ZHcD?6vj1{zESdfb^?te z|8h_Mxq-1y7(mvz*6*8x}$ zj7#X(gBJ$%w}CXrvL`i&Pl*H{zjs7H>bH#6rL0wK!OWU#JU@Ssx9jCFd-gr_>gg>~ z5;PCb{;s|8l#bO#VhEQJlb`L+ID^1s<%^dihpMGdtH=+1HMbXWXA;j8-4(Yw$D9?zoxLMCmMidqLZ~w9!o?fR4Hun^I81eQcR6 zGPmOvT;tNux4D|JPFim)T6srrVvx`N^jXRvmz~U+c1mkSh(qqX(W;rCP!-R0we{-H zCzpAnWxIN)TS}*#$S9jEs@xCNHz7v0Wdg5UAt!C2tkrT!mt3)(9}P8;9Lk^DIXkx# za8vL5)7sSvTHRnOYu)0}bEjlMxy`b{vsY9W?|d+uUv#T?u~QG@M=#?T8edj(iADoy zS)`auL-*%Os>UJeVpjBmtM742fO zR@bxY&d6F3{k!zA_TiIz$m2vmzs*q?qnaK6t&s{Xj}+-m#>y?vpY zo!#UTtLr*zT=+$bX*<)S6}2=SiO@fU*6{4n_G0(7T*d6|Hqz71@oiU%%c6WCBWV{re!Oo$T(U0`4=C zWPv!$!omoVgd8Nm7J4>_6dt0&&{3drCPG1hHQ+|{qZD3?2}J%?UuO@qpFZ~7TZx%; z2?`8EeBDBx`y61e{G6DN8=*F&p?U>mgcQ!Ji&Iq9@NEK37lqoKtP21V^6czvm?aGi zDTU;?CHl3M(_e=x+%IrN5AV(?QzA6t4Uhg0e7VJB=*5L z9dc0u03f#9j+7Uv8fBQ?*=Sy@J+w;|p{4FDLu#QT7=^wNHQWs6c%=M2QkK{8 z1yl%4LYaK2GK(Uv#jTt--`}Obtv06(`>Y(b*aN#S3D=L&#ydHJ)<%s@sal>dFxsq% z_{~Mt8B(rbXiE#|wQBNns1VKcx)U~1*Rv;8IdK=)pAhKE`UHKl?2awkbq&m+ox0Ls zY#%3z`q;kzXF{W8;*yl8PU%yv583m?8+x*~TbV?yOlLcvrlW6{Zmxc3aQd)~Ug+6C z=SE*8_A^V}#=(K&F3u%Qv9SRPTDJ0EmBQpzFMAk8X$)udOWi(2H%!z{zZxpam!(V# z6KvG;&7ntKQ7^|_6ZAzZ@xZtcS0TWxyzT0Y?OS|xtcOT)DOH`Jh$DqeAg7*-O=y|; zMQI@aBo~+Y{*v#s#P1Vbc1y|6UkGFeNL9^lU7MFKw55$Z_Sa@d-<#>2z`4$sXY;mW zoOtrYrs>>UW!1$ULz*dWf^r6{lcwf~;is?A$CD%tBAOlj4-@pCGw-=K89d4u+{t&< zQ>kmmRAQ`s?TK^rl}?*d<6S(~`4BUJPKj7pG4|0Wr%fDwN}S_G6&VVi6S>Sy9knk7 zHTgx@>nw<6z0<^>Oz$)#(Np#%sTM_lc(r<{eGD##@q+>j&*D`v3~os#>>rVc7i-W< z$L7~0C@lySNMIExlZSU*l6;`{f_}1lko5OPuJT@65$!yKb$^2%3O*ixPExN`vGdp7 zHkSxDWKw0fIy3^iWlq;tU-09JwH!Q{uNjj2^;=fy218lHQd>@Oooo6q(h6q&FQ{-3 z{F<4k6`C9nM8^X5A28k)TJ*6%G81bJB|Uf^7rT-p5kDb#5>bATz`zXY;!v2W)OSMu zf}x#VBmxq*(E)!Ozlc2{2}vtB6c8&Oa#e!pB*4ncz+DcB-T7C(0+HLF;da}$p}7Z< zu94b|3+wAZc($B|2+uZ)5E2G#up)^#=V;@OfyXj1Fz{B*5T--!HbSU@9#^*W>MS_Y zA&b-S{PYM6S!Cb7eIuo&#;2j7xi2f53N8U8Rvii|@rzSO#Agg|f{Mow1pzfpJYmY@ z3CBS5mRn!C#4MWv^A}{3L$cuR3ch}=B7@3)#7mI@mO77-HEXenX*48oN5FPG6!uXn zuo{DhvK20a2xS>)>Sg%DU^aV(wBRc3nP7fH3L7l+s}aIADSi;-l|m`{G9e*BYg9J0 zD(>l}TfRqfS39rlVcQgNc!rba+tkF$t@9XI!OU@5$b@5pV62mf-?jxmK@F#?Yr-%z zzHG5@QG5kY_sX;IEFI%AHldc7$X(}Eri;xlP2LMn`QG+=QLJQ|P1)iBis*5980N{RRB|F;aX;G^QxrYsJiIKBOFYsnRJ&R)=c@F9GL%=|=fRoA z%22$yA$uAUH=X+` zlfHalc&Wg%;i$68@{Us|$-Vskx;L>Tv)j{}rj0p#md3uyq5SD6%H{bNuINv1J2uBn zEH%;BV`n9KGH6IzS9-`VZARb|Pj{^sT*Y@2UJB-Q5F7ZhNvQhR0?+FGnhol~TaUOu zI%u^n&W-2PUoM={I#T2=y+zFU2Lz%C6?28hm!s_ zd~+hGF=pqGqhKnL^I$h2v{!7K@4nIcs$4gFY4h=qX_;%{cwlw~D zyMbFfDMvpC*Mw7 zah-Z;o@2hAoDj!X`=y=nevIc^+@{Z6ehIwP_r#m${t&Sf9j^%n=Bk?IoQr9wL&Wdn zmR}>M5^Xp5>>UQDDRB@j38{Saqc4Yl+qsdh$z9)bU1Og|k6!cH!vk~+*rtBKEdVne z{I?0R50PNd{GHPyM@(#N%Rq^wq$GQ1XZ{?$*Z2*gp|}g0-Z12rg3bsMaR+VVCa534 zY}^XXX7he4%Us)BrO)U<6Zw341l=5zAM6O`96AtuTh_=6L{*d$I9I|=$@`!h2nI;P z@Yq-yn2_5Z)k86)TA)t{I8~&A0d)L(e`BBnO#iW%@5qrlTr24+gv^S6Vk#Szd-N=x zL@W}K%5-Ppu7zo-^EVD=0RI94J{h3N9S8y3B$6ANZy*8p8?>SFD9%0H+@>LA5V@}L zc^P+$ZHiwBYmyxh&SF#LYeoQ7%9$Xdp5-{0Ki>s{1-%%h;UZvuXenW@} zr9q*CMECAF=7+cp-@4)ZNU4IWLnTfUv4j0*JxQ6{d*Go+^&{2kBUQh%`J|tt_K(H3 z9Gp=fTvcAPcRRh)SMX(ZiV>D5R>>&X&+x_fRX&d(hY;-o9^%>0t za=ytYyM8oe6EY9xl~uogH9hTR*ulM)MS$F31CCW6Xf5?tw36{8i-mHq({={bB zO;tcjn^7l%35K$psW`r}YwWfy=5!Zthl;o`r^0%_ zw%y72cBOk6Uwv;gr;t1)3h)7793*98lFs-3b~&|@2KjZfY}-nFrLeiUxVTVw6ijR2%e5t;kY6Z2&eS=>t?!j~X1MGD# zmOAN>bfN3lG3n{)!SICeO9oCzg?Z6_%1RhR!*Flix2L~t?!nB_k;5eHOdG%AXTEiv* z>fqth0)lOBvoH^DqUxN6zr()WSo_V*&jhbO=G-E?lb_KOlgV3<`&BM?I!lnSPi_5I z?mTfzr|c19QTcj?Z~f(~@;({_(F893*_8sSU9ZhGX@<-D#PZ|BqX`bsIE@+NsUhkF z&7)~cjm+ENph$x!wMMBf*Q$!*4RajZVWS_kI?Eg_Y}azH=gs6?KdwJ7>>&gZUJ;3W zoYa^wN5uNdybjvwHeD0LR1Y6hV$%BGc-20_jBl3co%^!kP3l)bB;tnkjq0e#m$R4? zYyW^?x`848s?wbtbSjlsYFD(;8I7rQV$_G%<4>zuX5Fij5@%T{o5a1a8Xr5Oq|Fn0UFwGjO$5fx1R3HW`| zir6$d*OL}u2td7q|3pRg@nnswqOq3D4U-tQBLfdue@!oF0g(}$t5Lv~D}<~^2=%7D z>(&j6aU~i+Yk`57h-4s>-?ASUWxqWu6TZN$SB(j46wVHG)XHEU0aSCJ+W#JTnJ(lq z!9xRe*8{NM!KKt6YS9^JgZ;xPDTcfNcqU-Z5de!ZIy!m_i2F!lbDJr(V5heqw}^x3C`E2m{-z5PzDvnCtPX+*G37G8G{apzCsCQ1FL56@hmX;ECN^c^bu zu<`NsZL{ZHTsyySJ7)Np##Lk!^c}~%^*EO8#4E54&015e%-;Ck#(2;SzXh1*3hwiReHf7z|U|?kB+}QmyyhOw2iR(3DMYzKzSKsbZ&gPnsd?p z`h1OMvqOQw+_&gF^hqx$6Tivfeeso{D$9y_8^?RUSYr%Ew5uJD602hDe#uJi=xJ0> zzWP#ep%|8xcW|-!y|+5~uUzu=WTyR~jJ?GZ@y>|y>xoSJzFa46^(^`&NBRq{1{nPk za}`CMw3mH-Bm7?+qNq9!ug^`*MJ0-@gfA4YfcQb7o~FCzjC)?DRmU3^4FH zl2I48o2=C>;}jC20?ywdp5iqIDszLk5OVsj26#Kc2(D3#L^Fw+5n+ikhSjK_L4#Q_R}k^~q>Iv)@#K-fY6u~-kBb}ydTLKv97M}GQ>R_x&5p!q-!Idp?r zXnYO;95h_{f*GKM&+P1ySNAj0(~)3bz)`4qZe+4R7yNxv$wzll%c1;m;CAHLJb*|) zXg6Q?WVxyAbK6cCr6pJWm&`$mGTzW=JJqwk>8f9FFMhif~6c%xXut_O0AdF=)?$wIV5gGo#dc zbFA`&YbrWd`yE<*h1!QVH?G{+Q%p2J(cSTQ8~m+F_M{dci{n^O<%CoAswY*~2Ab*R z4p(xTYxmMm)1S4_J*m>Pt7kEBs6JsA-B)*ilxRN2HSkH(a@DnP(AV@rF}X`A)BAI) z$4;%Tr&x$i7YiPaNy&xYqxsrSzI~qf0Zt9?amH5~p1Hcd1%9XNR-^Sg+Cy3`;&TT~ zkq!xmDpa$ZuVZaGUJZIHOI6hCbxYeVQ4(iA^5|KZHW(AyzYo8T)|*KrzW08S?*UW@ zCRmrD--Up1fY%ZR(}3)Oii(N{2t37hLIlu9bGZfeFQ2lK7@a_@H>@NqBpVGHRWmTq z1poaE%>N*#zF&Ad?a@!YdatVh10J>K&q!}sZSi!ofa$HIUCoB?x-3=**x31tLUTPBhyN&SA3}Ng$pdr?Z z3`qUqJ>I#W4MEh%I1`u&sc^Eo`}oX($PbyU-hpgvh%}UwlWT+r!yW`3NMw?wYkKF$ z+c+KZ{L-M-g06oE7@HCC6`4 z`a&MNw~!GiRO&xhAjt8+yf^(-DIJVa<3wG#5d#yRLH0JFX%tZClIdp>I;Z zg^$cbtC6VJk2D@=MxMy6G7JB8;$umefBM9UPD-flO~i zpQZG|^PwT9JF8EG=r+!KkADkd z8Q#rjMlFloH+C5_-D09bt9y6N-RE47*!A-BipZD=>8SdHTCarmVYBh8*vn#eE)Rty zcg{8X_XCBt4EDWeFFxd1ucbLM4E^ymwTpT2!A`(i`Sa(xYTAb7i6I9|72=0_;U|Ix zUETd@ku)Y3W+i#TJ*C`EboJGeA(NdC9t%b5JBb}~=X~_Ib}cpa5w8F6@Jny{kQ_R9 zUe^s>O-)V1_E^qg2qLSktJ{X&^DHbT>dO(5ErmHbqKu%-%Mq9uAE&Kx-E0{$@C^sG z&@)&{$e0N(S;%fmKtS+mG{Yle4%Uu|>zNvKmSKBpf#vzvOE856P)7&_3~6t}jSv@s zQUST52p%&nNCg|Mu!}_KOrSMthIm~R3gB$f(QaBuN(!uaWh+-$vrm%lT12gH@9Z?g zC=DUJK_No>=nEO7#>au+2_ZN9d7>N>Ohhw_+`aY1LCvTmv`uO-Dm;`!pFI;wZb|^BCAk5yVs*gt} zC&nNdJY27sGBh$mhv+?IatB33z4ip8v?1*ySn2P|(*PK8xKXp&KPtIfZsNU9mvZv( z@)(0efK>MT8|)>bztUD6VxE-C`JOjCyXW{wa{fE%!<01>K}=(x8^Sy6c<(QJcM?#| zI#$pR+~ed;vUz$vvrG50pUhOo2+OP8hfe6DmsMUD6H7|*eb2QHjyt)S3C+XyP1)IB zDb5?YQS0z1c7e4jUfQ9nR)#K)zu;4D7e%*EQQ#c|V!U5#WWLAwL+P<*5$zrM9XX=j zy>*ExrDYxEbUDYta(7va?{f3D?yba=7V-=v#kIv>>1Wzb-}`8%wK%}8l_qU;zt5-F z^MZzvHPy!8R>Gj<@oU_=%ZkG5#@jOTd9O-qe3Pn%a!KQ*3Ay)C@o=+ z*N8`JgcWFLViNc$hAms6{jSMeO)8HU1)Rq(VF3b_>`f2><)Z})3k#HC_h;od)uOH& zDNnyVq7wHKk`1Fmlb4s*)YhhODltKZxFF3~YJcITtxX*k7S){iaIkJ=Eux8 zZeYPy*91+z?fGAq^PJVB|ETAGJ=3gXG>VwbPnoN&CAIkDcJ&U+SUR%R5Qk(ete8aS zw{yB1;y+cK`l*I4Mfs}c$&&a?bimra^lW~g$4%jqI%b8>e(T@RwBIqiYuP>773Nlo zO0FsI(v87vv#6Gn?{emQEYnx}(6jS-0@d#g$BxxlbSW0jGAZ_Y(Vlcf|$ZwMylrVDy~gvt%C{z6>RNb@ov3vC!hz%A2&TX0YR1 zy7Kw#bO<8RM}zr#Z{WHiw*Y^i#5R-lcj1G95dd)fr!Yz!Bt}af#d!6yaQ~`Hq!PJj8R{;hEuz06inV=>!X?Rc$NPff`EiTeQH)#L z6CWvm_*?6ot|tm8nQcrrz^nuf6wg zKkx52-s3r*=dZV8uh!aY<@@~%_kCa2b)M&SUf-HVUuKIcjdbrQJKr+(n~vD#r`n=C zd)IND7thpeKRLBBTJgR9^53vMvDvX>66V%RhZ`aVA%8p?AO-Zv{-M|cp#$=Yirf(F zhqn~uYL>L;M@91=P!l~%%8XY_ktM>B60$M4lRAJ|IP2PBv;!#&Gs|*Ch*T53f+6jrN+{A ziV*%5CAs#Z4HQ$PS)z;@`m8vxn0%Vvg-wUgOkGo*R*1X6Yd`$FVD{&>`U15d@0!iS z`d4YsJZ||)MYf^T7LPYT1qn6;c=jpxl;)T^w7hyAZ#bz3sms{*XkiZ$Nl z2_`YIeuJT5S-e50agwRzrKqUuu%jX@u&1kWXdw-0zun0yQ2P^rwc zt9-?(Sg!_y+yhUQRk6kl!@GF1A3~?? z>$|RSM9O6S!C6b?F#J~5N2{{ojWDGkZHV47@ zg1oroA||t=uCA`9=lv?iCouLyn}79@>B(p((V!zP>GAPeQvDWm86 z?^jh#*|B2>=Byu2N^-nQ-T{{dS@sMWRwBdD_(#Z_Ot2Dz z9~p;ZCs-M(f({o(l_+hM@0R=XfN6=hS8yBg=}->zVp6kE#CB=awfmf{SU)l~$4;LG zs6p_+>fS7l$L!$rYf(K7NQ9O&o~vbzJZx&p_t~Q3uLl#Y)YMd^e*XHm`8C&Xe$<1D zV#q;LGqb+vcYlp-m$(nHn>4jfhZcH!LcI`x{rZ&B<)@}rFP-nJ2^Sw z^=2cQDnDr#|MM5;bz85z4rF5X-_Qa=_L&T0p1$=oPLy{jtz-gj7=UV9b}{9vr<;q* zk*LQDZyhv!_Fz=^Wcg^=N)+g*g4)aG)jcsj%*D;twIP0mzYf-E!JeU-=HV@s1%l(> zrf>0JaON84!C&8IF;r>z{JHX1-JvzM@C;17+VLkZX~XzqAOD@jr}#ecIIJ&yf_5e{ zJ~wYlg&8I^?37EcUaHxeiA*5i%$fVhTiufX>yO<7@yNyi>tq-tyubiRHU&vVdTKfs zJK#?L@Nk~t;oF=E}UXvK|Xoi*EZ!~JTVmQ0!_lHv0(&wt7eBi(%S$9QkqjGi!%0+UIB?SlU-W3K8 zjj#HD{9a_ld4Bm(^C0r?`uaK#dAe^w!3w->m+q!zl(u3k)AvwqR}b5~HOT%maiQcs zih*=cXH3G;MJSt)rneBh(XtNWdwG!Gp1|kOQHDiMK50;^(bB+h`k#ACkY7=Nz$57_ z3VLm5WW>Sn2@00v{PUg0M`dR#Ia;6J(ca$9FbP?&wuS0VDj(GyRKSh`e!baUdpND& zNUy}A#fy6(I=PekpD?;k=?`vMy!>;W7Ker|Q!;cET`TzbI z-0emn@i^X0`THZr{w@=JtgCkIl79XAHQhtRzh4J2(glF=@M|THH8s)8 z=WzP7<9{%`GmCmMs=K@q)!=1iXHLkfC^_)_{T+DnG)z90PpK@{ZRFOmLD{pO0pY{36AVQUNB|m%#*-z_BcB~ z!=TwNh|e>k4E_09)vE&)MT%ww^+{*&eLgR)@Z|by+n}5CwZnq3KDU6rnF7rF85HTg zxXGz#`ef6BRdd0$s5Kn85EFKvMB`ZBFHlVIt@Ox-3;YQMj5;t(FvX5FH`kPtTmK@} z(v0DXjg7U%AD2LwW4I87B*ewForTeU%ERcQ0x!qpg9n&eEVItYNDm3%0|>WT8g7`f zK{VluyP0EdZqDu#5EK+O-B##mw#wzsIm0i4-i@?{P<)hUaERlX?xI`-@&dTZ5*xYOa7{G58D#d z>vrF7zGL`>3lfN|Vl_22tP}>72OlHK<%rxl2V+1wZV2XxmWSUN(`Hwt1tSaKBVF6pTl?czFDQ#Y-U5 zhQ5g`dv9SZ2y*6C<`K9dQ>qbwDG_mS`t*6DAp{;a243NlN)gB01xHR+{s`k4n%I5< zZv%_5TDHR7ULQMIljj%TS-g!7ochcXZ@=3zR{2V)X=w>+x-xh{_o19b{>m;YD#|qe zN~mc|_CNwh>2`I5{H*!&uM2gWJ)Z+>-fvCiC%NET=ndCGuN}dLDTHhSkiz(CNN!HT}taI z4ea1z(=s&VL?AqgN%A;P%?8iENRW3(k7yobDFumYiqH!JTexiW_{QxxH`I^zPiGW> zLL%AJcoMynB3SzX!UC@U!DUP{OXT4m=uExkBi4l^F7C8bv@j3ZlO-Zb95g%y=7#a4X&}@v5UPH4<2nR*4K#<5v0oZEa%FSJk zr05nVD`3ai4D`Hb2Aap#?ScaLSG9pSKjQzs&T|xd6rWrlX1%skUm#?V(p53!_5I^i zXlb)U^wvFCgP{b#PyIoNTnBTpDR&QlnBngYYuDSY*(rRusk)jIwZ~nrA!QVb`%^dg zZ`rz42%Xm3xr@f+POMO(~I z1xjrc%ZJv7xe=b1v2bVO@YwJ`qhgNk=J@#d zDL9X}Lw1QHN;Vvp>6?$ZyH7PUGjqi|X5TH!lZwn5S3HYN%k#&xH4>7NlFfr;fXtL9 znq0SUyS_SjHd78C5&JeA*>@N*RM@(vrWaVnkZZAN;XoDGYxm_!LWdpF@T~nf3)drW zX2dWHNJ2J7$?*bCZuI{q1$mv$g5pVlO^e3Nj0dmxW&Z3wXe17+eifr1ik&@;v2CwQ z5V-mIaEUwQ+OVmfgHT0N&rE+~c7p%)>(?oKDJm+uzc!Rvhh?T5)ceJGd-R$?I$iJp zz6UB^J;W*Qr(n+>8h8=iG1t(ksGy=!St)vKI&vTKN;dBrN182|?#9npRcRo5A@!;1WTO{5TO)@s{BJcV2s}iLHhPI1Gm7vm| zGHccvq&>vYse2FeAtgMKld}x@`TMrEpUTSNoW^8e4kC=SqmbE{E+3lX@EMKOgw_yP zL?aN<2U&qUqZetyePo}=B3I#Au-%e3=_kAdM((ZfU6$UMJXd45XH}TA2|EQM2*_rj zphO-`GYXiYIU^guIb)s|J3|W9C(+AKKzYGR-NenHpJ@Uyaiz`eD&P}!aC`ws^1}eY z2@GP~DF6dDL{k8F{C+gz>57xCB93KzXIpr&vu7a|NsmWYH9k9zQlBF-RxRJRnfESj$8} zf`DJAs;U}*>o|j!ftZyl$om24J16Q~liWR)y*O0^akM1W?>r_UzF%W~2e`-7(dI8ING4{_ z5LBLv6tt|VX)~zYGYD=1`)?wm6>T5=6z5Txe@sxPMQek8>>g(J5MBpla?)dB-tym%ZM zO_a6?f>0n{aP0o7#8AOqGnR$~+5D*y=^z10sN2$?$h-HSW#RwjS{va<_y8q3&A6hn z2t+WkaFBg*`0E@;9@SA8O_Vkau7O_o02u}M>jyE;&J}O5ObJm1)1tb3c`WuZs`jMK zMc`2=jwEi%r22pR&Yhq2NKk(-oOV42*?%+$szrheU=pNR8R=3(!YfwTS__eE((6IM)3O}`g9 zkJYrSLWK%DdM>~3-@PUXQ1FLbCxz$qty6IxQxA?lpkwUyv(VFBi4gbG!iLX;O*~HL z_s9&_0Ar&~IZJbx3t{g8L;y+3Z!yt=MQEPL`~eS^2S-wP3|utk=ppo~{L^Xp6j7`nFo&*Ygx zTueDGRY`6qGBS3b0Zc;G^YF-LiOiWZALyJP+%%62^u)4S#^NQ`AYnpl+kg>KUB__2 zjOMAM;P@WqYjqggI=_#vN=6drn2Ca(DTVTKyQAa6@PfN{-<^>!e(Js)_cWQI(M~sR z+%U~U#}VbKsQNx&aacrUfNkE@kQ$lVg&unC)3YT>XN-b*Z5y=DPJQs(zm6Gq0U;q3 zKvfU!%~*g>LCtO|D{q+OvNEc8uBXM9wed_9pPHCgetW^o~9Fs!Q zJB`Bu`AuLDmN`JRk;3Nq@#9dLOf?LXW`kmkTJv!iddPX6bU$`N)RML?Rwou>q1g?a z>ac~#F4n;HKmm>e!9!Bel5JyS5Un+b@d7t@KB%lr33qIo1@gtk#pPT-b=P6hum9Kb z7=y(YBaWU2?}H#*`tYHFayPP}yvR` zH3VBVX!dU?Tcd`v*}eM=ulw$3yRB7++5rTg-g9zRgq^DkY^8}O#&m?$oY4vXyc8k(_o$e|r&eDC1T z-EQAngbf`kVq5`;F{MD4><{8C247imxa0MDBPc$VCbgiZ2E2a)v({dOAdna~hlgio zDo&ay6aco&7L(BgCe(okaw;^0oyPAYJ@1bH(7>dwx1N17m}^-49?t)9?2W!ZmyCmN zvyx-oL{deg2z!sI5riP%m_;472FX|^%B_W5kJcmJP{R~?UlbP?69UZz6AoZT?*Z|k z)(vr!+`U88vxg0{S%OK`2pKp63JG3}5oh6eVURlv_R$s-B}m5t($IBu*d4&3C?abm z(EJ}QlR>c~8TSLL2WpWfb3?-Y9@$no78(uUxD*-5#*p9*1&b~rK`kQY z(ciUDMAJJ}xfa=nE^3%tckV=7EP{r)F>A>LF+`vPuSaomD#9FvTR=bt6ZZTO51=UE zMQvvbxq-yWl~(3|U=tvrs=5iYcBU}wKB!`Il~yiWwsFz6!c8D$IdNiBH3sZi7s$zj zGXH6M7~{#hgD{6bjW7nb;UqGdI7D$CsPQBgEm{Y_uPQcV#}Y#mWfTh`NYZFTRX|{%Z(iPVPcJXUT~Ec_ zu#^Bmk6oT_UMa3EK>An@|9|G(W8Xi1O=G|RZzjj%H!$}5e}@qHzhu7ScSyhgk1-wp z_B{Sq%LR`$q$x$>iw-D`H(`g_e@*7$MZC}jx^I%>w2{Wix*Z5`TLRo0hz#h0+-p%f zxYGwF7JoTwSG_)AIqBT1^LGsk8 z0|}!nW5I^2LDbHIZEjkG=CCWx-tYmO&c9?Lw0yU+vV7Ch7SmydiUZ029QW5tFB@G_ zV1w}v4KoU>E)?GE42wTh2zgDP%JdJ$EiAA=`&pLNY7EGtLF>?R|AW?{;ewF;U1}W! zL~QCONQ`lk(^moe4(%?a@EpKB-x$(Y z)Un7@XR@)e(SR$^>bI|`AUso-&+y2|c00QXI7D!AwSr=)8(cYpW`5&1-Kk(3WNMp@ zbalTis5s&O648q~MZrz-xS8$ucinf-X87WYptt+a_C~2y?9<&mMMCBRkOjYjf&xa2 z3BaJWfFD18`0h2&Tn1>HP3y;(XRCH!ffkdK!GSiAmeXlv&i88#vTvX}sK=igJq#6D z=`n?g!yzxBpkVCrdvK6CWZMrXz_AEOlAg17y|=VQdi0*r)C`;uj#m>Cr5*Dut&wRz zO?m@KFctp{7iQJRVbD|@FhQa#*nRL=vSJUN{|Jpj8S0~f!9jtKD&{%gq}E+J_p74l z{!AfNF=oorrAyOBT5s2CJXiaJknCrgG?(!ML_}kbA$#Xx{7@0mc+!gBKyDMfrfXqj zY5tf7 zCct~G8W&T zgQJy+Lxww_nS#m}{L|%>6BXa&ohQg70I>h(k` zXglJNl2N;|EvkSI_yeXrNssou1e4&8YlNDD%=LH>#I41kWHYa10fUw^^j2v5oWv>^?|_3-#BnV@3T!Qpn3tiG&Apg5~|K zWhK<^vSS?1r+gzIK6S%T+haC^6~)Cu&DN+;`)H3wJsZ; zG@PI`unfDKEn4%}gb0BS0-+lKlEJWTp4?V0LY0_ugxh6Lj-vM`mL;PXxx_)~swW90 z&h*;6HxHW=Vt*v)ty|@Qk!H}_(Nn#GCjDa%Q)6_%4k-ide&WFVfmCMp1@ucW{to0swYCx-l})de)T73vKQ>?l}2mp*F#fet`Tr%Mo3z^FDq0 zgc=pf#}}X86deULyJ*p(a+H(Ha*Oa(*?kuINa0DG#ghPK-!`dL<&A>(bJ5m2f3^>= z4qJ(Mnm8peFz~|F$fqc0vD+swx4>tXmzP&{Kr+pb(;hr;qugyoZD1pO>PK!hWn)^k z9lA=$nG6&ZE|$aE+M4n{TK7(B11G*s^bAJQe82A{vA-?xLwk@%1;6V&)RgKeRQ$YOKA-BiIo(Z}r z{`Z<7zT?QL9SXJ?HLJnXgek?Us;|gc=FV#$w(E+Y5<5o8r$GNVC3A0fqf4%(;{DHJjC{W1}7&V zm8ErcvndHjjfaX&Zl$~}zI1a$OvDHd`T*#B?l(6JLLopBlI$63n_s1A>DYi%%X5n8_-2CVN%O9A+CeXC8yj;mHps38 zKL7mn>l6Turov^r8~pCvQNf{3eE^Nw&ds^v(840(l5n7@5X-jb#}8c{9dOj+iP?3u%rm95$#fxl|s37ue^L>e}6yqy@I*E<=(d=iM)w{JWgat zrL}qZ(oaiQ>N-)`CjyNRTC8Rv{GbvM!B(wA4Ye#+@<7c2oI^jpz7YhbQdfwUpK~iS zHq6VfC`x{1@VwXg)1j4M2QXKWAGHHE#c{MsJPKJ3O#;;vphad}ym%4y#A>iKJwT+= zoY$>gd!Krapq9gnEd%$3gy;dB6n$~pe0Dm@U($ZISFlxur()m!ptp(o(xvuN=yR*D zSH=~9h&BTM6M)-u6w+u#$%|8GN`mP^q$-K7?g$RZv~(CQZ0QIilnF7QTSC*n_V)Ml zB15WknTp>9BF_9Nxvðe99nmNRrLLIuIi%e%mjOF~)MTwv`R02T<*fp}U|jvhUl zq~6J6EQ2&Z^y)@xKLT;G7@Tr~_$=&(xSulFD8o-f%wlMlX_08GjlOy-fxxN<5A?R) zb;|YapsGL*!&ChVMM90 zR~=b@CMG5ZaX-8RSjhUOnW!sGZmt(@cUr>4p|{SdwFdq~TP~*?9kR&Ui`q|M!rSU< z!-3zw4LX7LU2iC0pOLAFOOyciB+mo{^2i3_s9kGewb|10G{QvE=5S<0%D2O=$cIg%fEzfERSZ3s!1FiQ<)D);$}n5IUC&4Tv6|VSZZ+k(tIc5Rk7F6t4*=!@AjI=D54=H z__>{}tuFA9__ryzI8LBd@4tSP^O7Q}&5!E?P9Dm_2GgDl4u|1bt-e?U$kFi+RI>rt zw4}>ukK`sD6acX-_a(pHa7^Pxh=Zi9thRPT{{Fc>)OyyGixHA*oo?skj5W8wXY(SC z>kFSAJw7LwENkBlg$u@D>j7P-G1p|)RF0*Uqp_h4$9UZ!z`c~r^}4!#j9AU%$9kJL zUuLBX7$yTD-`vcDMKi+tX*J#h2LnX@TFw&)7Z;3J)vP<{?*t4j`-vbbYoDOc?)Eo2 zPJp+qQ`?8g$H7o_8~hw#)nYH+$X`HGc(w&iew%j0c#1eqFDd8S*hGmaI$-jbJXg~` zWZL9eB8x!A%Q25@*FI9O`3NFO54}SPuc|;oQ3^+75mQ!CvAQ)gyW@x#ys1bf!^h8m z8p}8UZE1$J+00yD$UaBifR3$QKhq8X4clN9sCNijra;Sl>$SQPa3g`&6J@I{_hm5P0= zVTJeltcmsMR$CY@n3tcgCM+14YT6cywBin%eHh^^5b+ECgq%zqsIaW>fQi;l{tS>Q zxQp!r0KeIbH)nG(N*dF*9W33BIt8u(SouH1kd*sJ*Vt?XNMbwjkOxM76cV&NkWE3bL~{+=+m2GZ^g%HQyfdMNr6 zCWO2eif-GXzY)QP;1kMEjn14oGY8fD9Ic8~S+_-Mk}mGUH2~@;2RQ98v??Iz{Q}l1 z70T!UEg2l_|N2Juiu3?b3x`Kf)|T~xhOXPMiLv39qK!GbBOp1rs4>lXYGq;9Ec%u& zU!*~7N*Eo*<)HC3+%4-JR&$p7PXpBsHT1jumIL@rYP*F~9yZCNI$Z=?p_r6IyQXad zE%*QmZC!979Ors({i>;zHFl3;+YNBvO~(O^BXtjBQ!}7*uNsUIXC<=?faR;~eW5x6 z2|UVzvuDo|kh5pAj*jo967M4F^0jZvpZ@RwOc1v3*|P{jw+DspefCnU^le2lsM4@x z(TmyW899`W)j>o;!srEL3%qh$r8x@m5F;Lw?+7$vLd#9N6=eqv}f3WrNM6lSz9dj%awI_gw+cEDiML zWARVLgCalL9J@XTtaXIp>ZY`gQ&19S?|!)+efejQR2n=- z(Yx@$)#P~N9m%7w5b-U`R1vhudRQlk#d+Ys0Yd+;+=&<~gdX#iD0-VAcYVMH;iz5> z+$s|^Us$AXeYWZ9!99JrR%ib8h=YsQjtnzjWV+&3q>-$^pAtQzhxxEA?SU490jy344^)OD{yaU9RQuJVu&E7+c`s-Orqnuk1Slc_E!E?Fl29JH2NzAto_>&yxTpdzl6MYvr z%Y%cU29HLA8AjbTUvrRO2{p&A#W)-u$!TTazaBs_r3(xZt_ukR?@x)wr#z!QwDMN~ zp6bo+y!psbpyt5xblY}Aj%_zW7bAwA+7#D8#8toW@Ng7%7p`1;^!jiPhcl9}P~WZ^ zR*YW5#G+zi`H{Mf?QWdxvcSU2QGk^LwoI6#fCRxwioba1>k2k4;4!?6W+*t^EkV~uL4U^Xk35NZUAUfygLvFFS{-NjhbvZ`R!(=gA+Tb8xciJzt zJJ~$#GNWXjl>Yts22ID|bxW}bQ01Yzc5B=!A>jc^%JbKUdnm!!>|eXzl(m{2l3ef+ z`R!P#t%k_Y>*Jxb?q{Tcpx|!x8=Ct+22`{rMf1tL|AP8Uu`%AT2FG-chGhAl3saC6 zhKGN9Q(&R?J0ocW{9@0EuR^SKc~`pUj^fg>oJz@VxMIM_ABfdt&y({_xUUOzmVm16 z>F>W-?d9b~d?mx*+P1U**&5M%AeX`4lw%*;KPQq4NrQG0tGZ$y={K(~#RqEzJEKt2 z#_D8aKpno^glDC(eD8DgZ7J-i;X@|NOCCd*i47^NSnxA%Y*WjTlF`9w%+pM}HBf!% zIJ<*QIR(G3ii1fN7B7`F)ER8tSTt5O?ls$ndD-tnH)xp1!s|kb1b|H4feBg3rU{%a zN58EMetzid>!}P3J1$jxRy+jS18BSs?3NmK&i{0=<2KpAfI+Ii@e9=;-b|z1eEFg4 zU~6FheS32rvVXKePGjo&Cbn=4dd^!LF2e!z93aZtn9Nv$gF&6uJ%Zj1z2W-EU9mFJ zYjAnbXZbM@s74-#^NDOl2t@~vLFz}MdLoG-_yQC`T&?244i>{-yX`K@WbK9_a5KnY z8OHG9qNL0aG!UTk!r6ZB&TZ;Yabvd;oAct-ZNPX!ICDWDEm&!xl(PN)Nrx@o@5?yG zG_MxUYa2dWe36AbX4*VCT|D~tpFS;EskyV%hmFp08#}ubDA7JZjCJmD#JO_}gQBTE ze-9R3`yWKbXS-W!`r1E$-45$FK{Db{%Xom%0CTaZb3$aqThz|L`iUABJvOWeJ??xX z6!}XuumypN?%op-n``jX>&34~7GXa&>H}XNoJh9w1Mc!3*r3|u>(e%FflC0@eZdC_ zSJXPXr8#G}RF3@8`bJicKEvZ?24`;gyb zd{8jNOHOGMw7O$^K;v*Q3$HcA zpUA>-g4L-S$gma#RfBG*?3QlFXK@@T zB5s&_Lvcw^7In2CRU(Vj;67wHf~e3uGOXvJlCS#%y;LCdS}FDX7#KqoI{0hqhS`|R zdf4Z&@i@EKtaD39;6YvmIB)Mic))=j8P4&iX>i9O*70-a&vS3i-W8kEf``oP`g5hO zf5kiK6%lTU>PJBim{f?dL1qed;4HB0eh3KHP__`m4en{#)%a>spOXTYjKd&sp##J5 zaz`JUY#P*#L3RG)7~sL8Pmk4OH3!}?p=uyK z1sVqG$0yO3%ZKu*h~I2zw$er0G6{09xrrME$bRne&MIfqpYLl%E(Zi~idLN9Z^=3G z6n9AtU8rM8U*qG$z}Z2+vJ;iG+F~i$dRJl(g8-;HruN}`_cSMWU*f77*aLLQQmoe$ zYi-%Z2$_wBH`^^k^^(1Ez;Hr^qKh?RcjI;W@A1XFBk1iC1*FKU{Y5G4S0R_^`S4&G zNvUBsU>=1hYT-HZ^CdO#z=2=fDsrm*#CUA42S9oizydBDU~wF_oX`O5x&03gjmMeJ zP3bk6Sy^>zwU7a3?r$^1gR#R50thYmkx~%~0x1%bf*D?eKhCF{nwfCSQ&v{q9FgXD zS%z=IwddhC^NM^8EgHM$BrIvIa1+-}|MWQe+#@+3*mszu#LOG?=8bG~I7Tt?ip8?W z#?4OBc_VV{rQjJ*vW1|`V*}h;iN*y52>lU|$U6gz9C+2T=?xX6UIK|q_z!jwfoAt1 z0kwHlI|IN!4tsGVNxQ|I_U;YA!f6FSK*_0gL)351x7dCL$E|kY{3Q@>!TRU`M2ofZ zV)OO+p~_huB9zqo)!La=QUuM4`(Sr{{~K%_tgi{^(UTe4OW+0G-ldgCm61c91dPA^ zS)3lT6jlsk<)g=i5`#e8hhCH@mT2Nwf%k(sq=Y zcQKv}h+hvD^L+3?$i~=vO5T1c+#8y!24(2xPA=|e1wc4Y9##|?{`u0H7mQoi=i?f$StEBUY&cUvT=RZ<=W1Xz6&@$OQuq5U6 z)zY0~Oi5PdKro>zGF__iGA7({0vb}wm-FM}f$yIyV%>Mxk)BeQ?P!L!6dI;h$E5_M zzZ3@-sYy{2Kk7(1sbV5Mmj>`w9a)gt5OE`=Yk%D3ov4+_GYdx`J3xwdMz2|P*T&|o zX+cV4JMaie%>LImK<#baXQuto(q`-1IiA6&8%w&1r3zYdp?rlU#W=VSS zBlY2`<{h@SzTkftt=#*4$u|-b68PW+v*QjKYqNp@~yPEUxRE|j+Y8(rTFfLgKuaBlM)Q8V>_(A&{U*EqI~4H z*YScUbx%$f-;s4yi#20LM09lOAq(GXQi|9|Csuj>On2Jmyh>r`Rh2$2HzY|9kVe;U zyS)BuUVc7p4(c?VJ!eiYE^JS8?hF)6chUam6BJvl#L+}CV_(V0NmNxNra;ubPx6Yk zHp#~+0!@oDn#00nY#x7&hdiigh;LSPbHCh!6UfyE20Jrc(ddZ8K@=wl0Udz%Pv9hv z1)`oOQf5-$Qu1eXMP=>|k?-E~pSH_4Td*aIhew|u3O3dQ`L=sOE#EB5`6 zL-$-_04X5w&L@bqa?S&57%X33FHikH!nY`M!xu3=)Zsetmd9OK&CR&L_vxcY(~y3_ zXYH8&45NdRTa1m*RkQG~QAfgPh=Il>>4PyVTrxmimsTAhU_DU7DERfA}n%s5$WOwLug)6B{D>A;ij z!-}#}Qc{$GqAXd7GB%J77p&$@1&5rkHeZbFh@7a^&(_?0E&uX&@FfB_r-{eE4XwG) z-xlsv^_P|R1k|3!NKvK>tIZ_|o|{7($^-|gW&1l;L3+pd0F$XopZIw3#y_cqD&;q-m0P=C$_4i&H@Tx@snJ+B zGO6F}lXGLF5X{fSX!M}7t`k#D# z?bcW|xtg_Ga)jsXxOy?_{i11VU+XKlh71KO_T^O$EdKR(f%;7iYT3X5>>ctk#yNEa zBV?|syT|qZM7>CPGFY;@d$$acX0X|-iw&Watz%11R$q#M>k0V<PuJI`|A) zbfy$FpVGFRJVZF^mB9ihM-0|i04Y{evGkxQ#nZCC=ei@gmU^h_v|SnwI<2T$;)l}) z!FAMF_AFYOVT!}~?${8LssuwuX`k#lZkt_piM7jB-ahC;>RbBqI5ca;#&1r**kn~c25xSUETaJy(`fdKeMTkr+@iRaUhaNfhU8o2V<~4 z{bkCMW23`+9(j3sD$QIPM*1pp=__^dS_E@eni0m3nm@}CK>sedv@td?*l{BE)!^_d zDmIm2)V9eZ{j8H}?ARkir3%Pf|JbzY4!AoyAB2R2o+8d}eZ0c1VtS#+PfjqB5CiAz z*^JNvw~do%QxC4a(*&nL%bw_`lEaCc}ukI`s|c z-ZS_%UQ7(qrRe=gcx=T4_eOEHs(5^! zDj_plW@GYSOaKI#5|#PZUO&SbF@%RW^^jfj2LN#09aMDVVmfU>5Jmfd_cHw zRy*3VAy?7CSi`mIM|!lBXZSE;;BC@!t`KU&=w4AfSPt!8{C5K12|dQGv1vj5{^-+D z%z||}d>F!7h@P-3@|3#qR^lMIEGsY9Ww;>HIQMu6L^vt;*t9^Nt$tnTZl=5vawpRM zk@AWZ6q}o`)}v03AK@1iOv!wLh_h_(h8>`G*nKwTyCtQ1!dVGaD}?wB?gY+t-v-%x z8m+aCO`{`2{viDru}QN;QkC}+zNiuUduf4d>&c$JzR>=i4L?Z`3|yh6Nnt#72PqW9uZG_iuDnhT10b934>tCBN_k}>F}x`Vkox0=*`0onwi$iZZK{B*kV@(rHj zo*^?p+|pZeQyDQ-7cZI;2DfH;m+XHure}b>p(=API%neFp25``w1IaFZ$gRnTs=oB z1#NfJn6mo%jjcPE0@kMn^fTbYtE#WZ#pwbBpU6<4ic*FG9nf!z#RkZ~9JlVn_o3OH z7tK!;d7ZpeaiF6FF~h_LZFBFTe&?ib1hoN42n5@V?#zA7@+HF(&~r^>GGTh2Co!5Cu6aOZ z%PK3^FkFBv1Ud=33=&jXk5t0QJY8*@3cjD1l3(B6M%+MsZP2j+naRs>P;(a4C&C#H z35$!(9xnX)cyA;4M$-~aS}Q*wTOb@a5Lx1`=E6OM?N^Lvk#Rv=0#7HNTXCf4nj2T# zV!nr6=wLM+snHJR(m_Mun#+^4f`nHnH#-Qcx)6=a$q)vLo^TpEWRkXdS-XvynOM?V zpWhJDb0TkK(}Ey#D(H8Tjxb}bF{l?Ejr8?auQt4w5g8dd#)xdq%W@NYm9Tamn-`Nv4%nRxo^Yb&?Hg{T*BwW;HI5~P!hZtOC*%|WU z++Eg-tRBzAu#aQXq(HFrSJn^PCv+Z$M(4JvFiKyNn*at%m?N1SDw${Clg!?Lh;XoIX{mF=7tq5_%6nvai*aVn_km=n2(Xof4YzNK1!OS`2cz`0qf^ z$TsE5uc!{4i8!xm5t!WBYt{HI{ltkA131!ho|K@)xV_*LB04=p+)BdRS6~u%Ma$!K z=5%5UKoUpIY=%QMFJD@8PxMJ=m8?Mg9J?^=jk>Ax!DDA0>9tMkOD;ZuB$9OAW%#ol zQYd+yq^2%ii|tBgb=gKhU1-lVX`=b$naB6}*uSWKKq*IdzqJE6dq+`Zx_Q(utH8-< zC0(>4Jq;^?#8|{1R-c;AAE*3wvZ$w+@G)RVYXe0|(^8)YjEFeK>}cxkqFct*214kg z-)7k=S>HQ|mri|u6|J9%k%dmCBRg0ozr2Ve+$9Kt99+uzeuE^_1izA!5?b(Gx1XxC z|6|Wp>`pxLClI!4`HLh=+703ht+-BcjA0Z8)c5;&n{ zCiaYYTzagA#Y?7?K1s(>?-=;gG3G$mUO@>y#6@ zG?9c*xJ$z{P?j3p0q)_e0m(QM_PUVK%#<`h%x0@$9yN$?IR92p`J~pxu z>fVj$yS0P`)-<49^r4viC}Ezu>3nCwVvL}$v<54r=1Pt#q{R>J>lD0|{OM79LigeM zfDY-*!XY)Ld_&_5`i>T2V6$97LBl!7o8_1QWWmWsD-L#qq5nD>$oX+h1mK}3191a= z@pKXOO`V4bSIXY?Ob3^4=%xSMAQb5P8S;@J(E^@4y>29j<+O*%qF@ETq@QtT0T?$r z3+X9?Fb)L5l~Z2A#jpJs8$u3<(R|{V@pFFYQGh;X9`fC>(AMCi9|iD)Rze_Gqtl8dXM%#}A~4URH-bZwIDWvOW4)1E3l1%Ac0b(RB#F287xV_`(n*;61JtSM zp*5$JQF@HEUIKWh2FQPCIs%{tupcTIMQ(r;3y@u2%Qo!j!wMqP8&tm)%$`+zq*sRq z*pVVG4##R4Iy>OeGYwrd&+-rMPLsM}jkF8-+nB=P%<5&5hC~*jc94N3F&S@A_ChBQ z_IVsy_iVwXIaR7gwk9G)I*l{(G};QL;0z>D1M%Ke+F!|)Pl^7};%lhNdEV8c9SFNP zZ9N#K_#uF!KadO4=D84{xnb(S4M|Jh>krd|-clj(3NGJ5P|nx&l6hn1ly1=zxHh%O6B$oy4pua2iN{9n)hrtY%#z z@pBNUq)FAjdl~L$&Ad<>;Ck)>q(ZVG|K-!%-C%~-tXnr7OUf7`RNuWbP12E0+W`_o zl2tyglGas25+Q~ScnUR&;wDLnI9vZ-F_Yd01HqgC9ax*7BGA>VL-~p;i{|Q z1{{DF&c=}P5tuKJZeL(H(Mos}tEiiMfnO=`Z3;$AY3cC^Sgc!(jsB_*1w>A(h0tlp zt?T<9&SleERmDE;()<1A=TI?l_47?DwFDg8>3L z96&ETks(<$DpStf^AF15f1zMs#f2b$AD9Oejbs*x2_0+fO@juy)%CLco8p7AzzXsg7P4IhCoMj!b7~h9a7m>MMqU|wsYo3j*0eJy7(s4 zR~|!ieK%Y|H{8GdTG>1Ef5X>^f1cX}eN6yat68kM*u;udR~s z3XSh$|KayS%~oiFrzHXB#?!rVElPXr-|Nk74Zd^M6`g$4#&L5vNad2tqIi}C29mAj zby+*sL69h&j7m~lCH571p#Sx0|9@?{h?D39mO)ZK99>`wC*f2=DJcMOt`ybw!EMkt zp{c>Y;U%Ea@B8MDJ;Q(&*9wj|%=pTP8~#r1uz;P>ddW_of=+wvZ98DzWR^@e9b4P%LJ&U!*F-Avy5!hu?y$m)5*Ien~%`{L@V=FdS8?TW7-zGFV%F;KusXI!Kw&y*z>1*pUBH zgBxgmvL7Kp5qES9qSqbQVaff;T_9$a5 z0-JjsacdLqksK`+ZO)dieS}6bCE`w@J|V+9z-dohui~c{_>Fe;haeS_Y{A z!Wjn6FbbglgAY>S<9A&d#vSFT<>|c@dj5t3*=%Ujh%U zcYg-JB*v$zs*3&Z4raZePJYf=GN!vhIxHK~6jhYQ2PFoeC40v$K^P zb-V}OPe!I(3VIxrFe>b6NDmTntjNvP#YF*74mH-2Bp$TJq~^T@X2R)!3b4ief8 zy9AWp)PoL38XAW{Yi>vF)nH)2A3zlpTUyZHqAVZa_Mo?%IUUhv~rt^wUp8 zcbLwRBX@rsmeqU!H(qf0PHo7x$(eO;OW|qsZ%qIUTZRHHQXShh?y#Y^sP$vaRsbTS zqHFfEALM}N@3shzrb)DDjFQLHO{h;BZFvRg@%gfLGAYp#TBU;`+0J zfn_+V-+lQK0N|jszMh9sfS(u59KnxL7ZVgRq5FpXaNo{Yl>lVP-G=fdz?ufnsT%>O z>2`D$t1X81$GTH^5Q#XZoyGiUfj|Hn{%5bQ;vfR#$pac&)oYN<*Av>4&p2odewm-v zO(vtL57s+V^jo5|O2pa`PFN=ZJ_lcyZYzh*1uI_QrlqB&!-{W+eJB$QsTq>N!ZY@V*%|Hu08SQtR?!!OE4w)N0RMLw6lRM$)ZdQqkwl zDhPZE;T7OdO1a6EG8_>{|9KWl>~VX%tU(4dFO$u@a8ih%JA}AHs!sb=s?Y+1S>)&I zivTQNeHnpqOSnl(k9{~f*d|P>>OehujasLU%}Gm3t3UbYiP~#bx+d*M-&#}7Y79<6 z?Yag*r={BES0vW+beuID4DG;Hg@HT-I69seKy_F9gcs3LFUiE1V=)3~r!L)eQ@kh;#w{ z68B?Utg#LJ2VkO;&${6u5}vEI>fo)6WVq%ne|$ZuSQ(||%T3<+1hPLPj|W$NM0nLb zXo2W2M$~j4y;AG3@fg(O*>X!!qvKWhbFm{=gk+=e$5wr4sgMm*+g4+)BdT8~2dGPE z7Lw`<8EzBtP?X1ZtrGK=m@3{z(|_ZTuewNMwqxk##(CDB$FVHj(daEA5w2r;mImE`j7H$rG84Td79DyxM$wi!A83ppH0ZWS zW9zI=HGNfJS#{r}*JSzG)y1VMvsf87C-7+FVT~PD%D6cVL|biqK45RVq&+A~$)WY? zMTiUcHx5DFW2X$)boPqRK@yNd%HEa%4wOmKgpBfYS;}FVN^S zfdMi@srh!~@9s@Fbv=$AIXs_d2Yn%&&d?tu9lm($-JUe-;knRtK`9c8G9G5QK5#gO zc#E`k?|_FeOG55Z3SwE-3on~eWRzgodQg0qK>#Aw?Rwxq=x7NDcv#E|(q2Y-L6TG` z(Q-a}kqtk0`-g6g?m{2LY<8-gFyC?dgVb8B7=?u(>t0&$fU7* z0ij&^@mefA)u6l!i)sQ^a9GRD`1qd2`*gHqeD52mK8N;?V6Zb))vW7Qc!s(%L+DZ@ z{I)1-ZzVDYrHi}Gz(}BLiw!_8#_T1sAxm-G470}3;l2$NR`FC`9DYuDapEPJB~vSr zm%YbBc4~vJ_Waexr&$hb8GZ{Dyp4GXg4Fe$I*KM`?s@ZW_D3hf6L>n)gQ+>y>)(BC z%9)1sf5|B;cwoYbJ&3A7G| zFT~&!ullOG5Q+|B1;dlRs@E)1Qxm|68nz8rpVf}vpuv|!?s9N%1dNS#7@!XW0imE#nm zPC@d~hm5oWY2K}!wIbB5iI(fEHWifk2)Z$B>Ji|FUmfupZVd|DUAyQVS0|Ju(IT1WtXOIXI#s=pPFm>(RK6lKhMS8?A*u9Rg z(o%^QELc*E6i{G;O%lWc9N_Dycdg6a3eYTp@u}&cM!H6@RLBVxnb20{VJ95bG)f0U zT9Q&}_vPk3IW*?R|0DKWFVD3hMf1In0RjMNEk)XH0_{;UaG}_E0U+=Rq|RTZ5ds}G%aV4srh?bO-8XHGSfV9^QK^(^mKPC2a zuBTdYxR5vH-TZta1u5P_;+CT9R1O3jiP1!c2#P&hmVnOXM-C}jWR1H6pO(~@OAiUR zJqL~FWLyCG+QqDlEW??r1G2`bBzgLLbC(xQ?@NEn7VN?|z zO8fTgNv+t)GvUwLtuw*(q-||OvX(VxqYU_0au!4V_6&I&1o*n#L`&%{+qQ55#pj|;(2v*17zZeEI&(^{t+z1ql-kBrqHmSJO+TE%l zHTKiPVKIuQ>fWs-4C(~Hg0nyV;0u>F{%exJx9MY03+Hik9I-R`hrfm3{Fvvmo8HR! zuhCOuHe)gq(Q>rysSz1Xj{m9)IQ-ul0`T_-kNtfdlrdp5{QrhsbS{BN@ za|DLA_nOvXJLuszzb@6k7A;D>7G-aYv>?>^Eh0m_BjGDXomdOUbH>VB{W-0~6mgrl zLK|}i;AB$xO|%POeG{laZqkIefO7VSBDsDDXIU@oAS4jLWY*$N5g=#@)0kdBK4^)r zFsZi)7sG*5Dhj2A$DUWWw9Ru7qo4_887g1CeEI)j z>`dTt&fB(sF*9Z?v)uOVLq&xWQ9{PrRYF;krHmq#y&^4Y%ormHS6M=3DXFZHwX)r! zMQIbVL_~!|St7mP^UB=!{XFmcKJV-EJok)|>-sOh-?<#eaUMT0wkA>7AOn9XkPaRL zx#t#wl8AXV{+Y`TeuTW1nK?xOVCo|}51DCGIBE6)A@31iPucpAuaXXlU=+PVojU>z zfpSF%<;_`p?K~P##s(1G@Mz{`KbH>62K47g7U2G?Z=ZAl;k^h$kt22Go`++W>E!ui z=fRpjW}Mw|dty9kq9-0R0dU@Pm<+c0KvF!4Mk4vt({q(6a$<2K1!OT)SoBe?;juDjNv_T<#JuISPrtVZo=b~hvMx=*Z$4D2GCZcVQ*|& zv*sJlGyIY3w3AUQ;#M!c{@_-hnHQF0x66`g`T>A+le-$s4p?wMGh30qfe0ihg(863 zkId4esCI|y&#;iaVKG0_Ymp&jyjI54lJ5whe5HtH>PF0juEgpKoXbdAF8Oj3N~!o$ z0F4L;n|qe$|Aw!dSgF#_!^+ASiJV@TU)snp*b$WOPjYVAhkQsqiBwe4a$sU7)8*q9 zrciL>{?wLtHwHHrIlzo~FC_|T5#Rf?g5{;vEykk7ci91eiMdffAZV#U#WBM!K~o?h zCIT!g21^Mf(_xq)%N9I&@=ri|@h)X%t~4%L63gUI4VZ4_ZgetboN(=yNPAoT@`69H1PajxWGX2eJg`nr=LQ-V9iM9n9G$d?>|U z?i4bW@DmdAzEA1Q<3R>Z>X3Qv>nYjM1$NwKt^fvp9A26f6XT9&m7JRCajsmr{ zK`q;cX-Rbn1V_O-BGQe*G4T~|Wd2NNj`EU9oTapf~a0F#m=EzBs{gpFp-`yv+rt5)wv`0B%GnTM1(XT$u25URSApKEK~r{TUuP z!AYRnoQ4{v6H9O{s*BX+p3iStZ(ZQw?mm(cGUoQWZMm1?*O9fuEY0fAro-2>m>+>@ zl8Z;uQ_gL-3y+X#CpQdFz1BV~WP>0|bZ|o8ArWe%fEoWyEP=tT270Nhf&=p#NW97b z4oygmPj4h6F95lrs^z*=wD@uhW!XqUn<52Wzva)T!5B!9sL`dh`n=ok>%uFBE6k|KM4d>(hJR+7+Fe~p)8%K`=)-!WzivY^{d{T>T+uFgH&|x zf(o0#FR zNz0(@N%uJmLUUwr& z($MA)p5@uN_i_4`XP?GhL3f3DdMB#zi?DBE_~vp4Ke-!wM>mI={Oa228B4(Ar*GS` z((CB+?|zmTk5))x{Q0`KGY+?Gy=c**gkLB3#{jb>Q;N-DlO!`q$10=2sD`ZKj^lPI z)dHmf2{t9GLThi3b8_v658~sDIoRg!`=2IB$SnMe#!Y6)S-k|%0^{pAz*P8B;tDyy zs3UuA4e&V=hKIs%xI;ZD8leqEHozbB8S1dckoYWP@fyI*+S+YS$X|gesK?WT6Qw#t z$SPT0Vc2)~S!5#&7z!>gI1+WkKY4N@72}BL_&#s=v)p=)7C(~fST7@4~^Y!_7>1ZA%;GHe9R>%HrKqW$0dS1`3c?u zQ_kIYmdvCOWJ6y9CyP#a-m-BRD`|X(CJS)30@2T^C{@bbC$ue1Ltg@>ivO-^mijHUc$pt$1!VkGlD*a<0=S>JiL99&svr*9c^)i$b^>BZq zV)}4#YjoVvceVJdgf96y0QH78ozvYuw-4I& zmr6vaItIm1F3YU>Lx9RnE%~@$SKBhTH+Y|x)$=~CqQH)RC#Y1x!o2<8lOPP1JvPUi z!?5TH!#6G2jiVT&5iqK;g6-Et!Er8oJk}0pr}SfDSiGKvMxnd%bRA^9^vRPaYhci@uMSEN|0fBv*GaZhvd$~y98dW3Rkq7_MN{)3d{JBH9DqS3dzl*6>g zg+4v^?wyY*Si-0C(-~*}kjwj#J~PF{T8TW7^?~i*F}H+b_h%HGTQn^1cBF*IhfvQIi{Z z_Z<{6(OPy`Ys#|YfI>F+GPB2cbMy-`3X}gzs(2rH$}wTjIJyj`Lwwf=-n)5~BgXViW*$}S}9W{y(Ic2pR8PX?aqC*A;%Zx z!>nhYhpP$6)A)QBb891bY6pif7(kv5*)rEENlR06HdBqR*4WcPKO4K}{Q#QHaEpwQwNifw~ci{c)b1U@4 z8Z((5?7`gZ*@#QJ^Gtf0X4&X&NjSHUHPgG}YNO*6P6%yA4!n2sW?PyuIp>Wt%zsEf z*ly(pwrZcSFr3KgfxcyuecSi_X8pwI`p-UHaT6u>#1#;NJ8|IkZ`iyfMQ@c>(E8WQ z>O!*F4TROQNzkM=!6OgP^--L)EfW{ZS3C@8^p#U3+?mX#-MVu}N?92FOMJ=FvRtj( z>k>K7KvX4zPr2~cT~~~YgK$lISv~~cjNqn0lxi}JEQh&}{(J3gODrU&`p%|Kf(ShWM1Oe6&8Izz=5&_MT4X$2CI=HFEBq5dKBNrV$g=mWTQsno}hoN zRBHSj)1ZcYH5{1CV~(yKlkF%wT!)4cw)ZVx{XM7L_>;K=eE+3rM@MdyY$Zos$*ErA z=QGfc1tvm^MgZ;BC{_-TmqIjuBQuG5=4jHYG(g=cbaOwc&+VIAV^ zJ^uBdjonLI01HeXK4Yt_Wno%ehzgT~4#RF^WE-&h@ zn(66wxBAEA7P@VNlHaYK{k((a^VHR2T3Pp24%k1rH_Q2V*Vel0dy zg4zl^z>>mT)g}2uU55U7&Y!l|iEsnRuBFAD0*vtF`c7N={y6*Ov(gxcj1YPwvbOX@ z_56;`wDlNMjaG#kxlQde?P#~0sjD`+bM69M7Np_<=$|sSuEF7Mswjeo=E~r{0fWX; z)|>z$N+waKhxFW-FzVQZaQ8`RV!jRu+0^W)HQGahGnIhI!yYs(YK$8KUxkQ zT*UP1jLkjPiZ&1ukUHOj*cu6Tv4v|$uDLk_IYStLSq%8qj+sv%KbCK%A6e;|vv$(#(%+p8c)67$9H8bQ8b+Z1;ARu!Ij)hhS}0*#=yC0s?iw4Po?Cq zx$PXOgPoMjhgEtsQYbe4qVXu?&Y-U~>=9FG>f{XUt57Js+}Hg1XF0Hy(W10nChqe- zeE1LxP87T92p!ckc^oI7Uf;Rx@y!5$Kw?~Y(9`tkGiFrqClfg?Ws+|;{4rX`WvR>I z2R1z2|8q^%ps{yKe171=FYLK^>#Rp8E{ zJ}03BqT!4!y-$2!j20@N74B^Rw7gvR_lw>5KeTuu$9Num_0(ofoAxVXU^s8Z1HTwV z9qm8D3m_3Q3ZcDBjfG{wIA@Qh4nw|o&U(f(E>@elE9FCJ$3wHu4Ya=a20<)m)^Arj zz;#EOv{PS;dBt4G>Zvq$KzW2RMnSifmhYtef5kGKN z#be$x=|zVkYcO+pG)9gsi|B|cb~&RywNTD(VSmqcbBhOy zO~(w7A*Oud$A1m69UAc?=h^SOapOiz=|Bv= zi0x7H$18Mfqg~1)8#Rymb5^`FpUxM1wfJT%|0w?61AkF0H#CX>AGJ$lC<~q138qqC zFbK-{T0VPmiX&*Tc$6kKB1!~I!Ln-D1%LTDWOq{a0`)GVe z`*>E~pGB)Wq2Oaj}E-X%nay=?#$Ci&n|xv zhpuL*doW2Yo(-rw{lo#reYdbzO^l7{cw9ShWF{p4jz>uZfw>0lRW(n_%UzefU@ze* zWQIz|fdX|3L8TzsX9b04-!z}jL%k=%}KNCaoAtMsp*1oKbQ2`m3#R>psh zV@y{vqEez>-&tp)s7GX=7T{(xnz}S2(j*u4ZUFYbXU?1n8!8J)p}6m%er^GUp)?%W zXZ6N-K+ZP=3)-Z~ku`2q1ek^e@;+uceq!6TYSWRebpRc}d`eKB*$&MzN=tDH=jBB{ zp9wAuqIn=kYAK=FUcY+fHP4sricZHNmbaF^!zey#Y>|ps zHW|*%AeHDE*18ircY$Fxgb=mLBGyH4q+O=;h z9~{XHv+^%kv1G{&beEIQukv$zKp`=1PRqcMri_U#@_q5_*#Vehjpo5<%3@WOz(tAp z7dxKg+TxhRJGPLemDCrnjGXsJf4z+QOe>L*HTLv`EuYk1dEw1mi^{L%U!%H=9eZ`u zmDI0+=8as6ii&ogK-#CzYjl*;d8WcJ$@o80UE2?i@V#TQS;xpFSJSa-qDr(WhW4 zYdGn8adOR8K_HaNuua1BxOKaBJsEhua^*@{{E-wJ(w)QLLF$XxyCq&!3L z10I(0Wuv26c?L=*rJ0tJ{bWs@#qVQCT5jA3m2L=Fz92m4<_#BLY}B?> zr{4dqVjB+BuWVgWS$PEi@b*PChj(z62N&N5C~yGDpejmg7bJ9G$EzMD;ouhw)Zlmy z0E*&mLf}&j(vBjql%R9~abYcdB%U$@mAUlr9v@!Mox(^a`k6u$FbAkB&M(PQUsN}j zuslM^hqh9K$V26@nkq+_=+#5Vdqh0sL&* z<;x{!-q66wBWkF2Fesg#?42mpPQDwR@M2%Kd4vzME!Uo=YChpAn*H#@z?1jx-J3!P zY<;~7HTg;(A4?UYk1hHR!<1kz+ceW|ZFNfUTyV+wes5r~LJH(W(!L}a+yPYx*m%j& za)`PV@NZ|z0%A~uN!i1K3z6-(mqLk4rWp`s*R?Ls&Zrr%u_j zNTwW}{fm2){BhDT^fk^y3MZ;wKg~WuXB*3eb9Pb|M%`q95KTN)uRVD1z-yQjP21iF zZ<@E!b@J`}kAM7{3!ilk$TuWug6qLWpAd?A&}kuS8}n6@76xxYeMYKUw`nu?xlfY8 z4ujx98KBtJU5kHOeW?F@kZ5p&TviaZ|=BL2Fn${O?5+Y#!( zY^Zy@bteu&sLEXJi2G~Rr%*Nv^_s+~^l;NtCb_bm?zbm>jmvM>w(Zt~ZWINv%$r>g zH!Zw%ihVq4DRuMoqqBd5;|+cENmdKkPWvJ?LkD)7!A@IK5uh(yvc!C8mRC(O)Q0kHVeF)_l!EXc;XsuKh-+r<5i*^i8ndy> z!rlWdY(liH9zT32GRQJ87Tt-vhnH*pC@=4bg1f?4HKemCN;f~ekZPg-t7)IYdUTc_ zNZX_{Nt>aGvAkHD;ll%9iK7uvmlN|6jvvoMFm;H5CD@Bqe64NZQ-4lkqvLj6KwwaJ zvpXE|JCt|0jH>wYgC9wtk^^My{*c3Ydvt%g9G!@%OOW%h#v3yT-kMc*mNdI!; z2(M`(7bZC}mbzBE3XG0>$)EI82I#q(RyE%?Y2u083xiKaKSpRF5A$ z>dtsPXFs`;KctC^8apGF+4pOCd3nFXHJ%jXU%z{2yMJ7m%JUEj9z2(FU5Db*AgArYrp|-3cRi7|QL(rs zvEr-7jU!CDAQ*%C`+~HpG6xbD(mj%IfaAcZ{42 zi2%qo9QU7PD^?8B;LxAYs8MzDv#+!bEHCC{WLP%&{b5)?SPYdWK#ANO&~T~wXgOzz zR_`&E)P!Em4lKner74n&9Sn3FTe7{_E^?~XGt0kL+9z>O5)aV74tW=l6>}&6NApdA z?ew#LE>qb1y+LID#>kwj() z9W#L-qflt9w`&%>vHq8>_AI@}Po6}RJ!@Xxn?XD`(pZe&nzD#8kW8p+z4s5~CA$fO zbe4IRz~8F6rg~jXG&^6Yc=<5Rtt`CpCe6=g?+e~$znCGvi3+Md{>0PnUN0s1yDEyA zB^Tg-Wh%^f2iBG$XI4rz1npRe%5JWs<37ZYZv2R@O_aIrJb}9uM}9b7|BVI~?_g}- zu02uIQvcU;yI72|V5qeZ>OXIL?)HEH(?-W;6x`%w504W|Yvx?C5L+f~?^gH&-q9>g zdmn{0cgAdkYd3EW{P3e1RWa`eT8&7qr|JE|!gsMD8@wCHe=2Hjb#)JeP%@fYFa$TkUA#~JS{fqv-!R?Bm|{Df}O26QK*9Iwc zM|O8Ubl3HWp{7@!jzrK~4n)WP9@w^vRxEvjDvEfkbynNJ;D%=D$|@f96zUb_kiWkP zjLFdQBDzR_!Jw(DT6gYz2pM%9q`Z{p=rm>r&**UKKetIt+m0NxLZfS zar;2;b{-rCMUc$emu8wEoOR2eiBpNl5X#z5N8oZ@L@cnShNEY4e10BtBRPUYNf0lz zG;8Q(<=QqyWF|yDr6bNqp-}VucHXT*qF*1*2nA^?JsmAl2FesmBo#OaZ3iUXIjSF- z+lGp`O`A5+)c^5N3bHm|mtK!Y1n3405-cK*eSo07|&-Mpdm9? zas0#q0wu#e201H*QJcbtBInTaL_A4;7L-r~Fa&hAPG0W^&HwczyB(seM#MTrni6Y& z*zgm~!112`w?!_8y-1}YtvRUGA^YU%fGOLyY(Zt(Ymdwa0tJD-tvg%^ttHJjoY3`M zCX>JpXo+*F%!$x-qpA=~cQ~va3aO05${~?o<7!r;e_Y!DNkllIj-8{qt2u157P*@%UYt2 z5Axd!^(MHr$y~KO9^$ZyUVjw^wENSC(&WTTM^0sAjkTvR9wsIxv9!qL7zYjnfNEI@ zNT|AcJS8HtU-#~h)*QZEk=9$Ula>oFMT%~v1AV1tESlPNkG^m6JCA+AVUK^^Qtov* zh#ZB=@$2=ds}8XD2d{h`rS#`7;bP)RI;NKkIZV$l{R#RRqlDnfG|@$4*EK{-XUGHd03KzO&wfV0VW zba7!~$*vS_+-V>Omj%eSW&op%Z@q2MVtXj5L3ZxQq$Uf>d{S^qB2$uYLfQ!?LF*Xd zYxe#31$eC1&{_lOUb}tUPM{aWmb7R`24sAj-GFzi@20qobk;oEjztol7;{U6f|PET zzsBsMkTQY+8$G2W9aqW-M6vsk4}=&`KnA4^Jj!HDa0w8j5OwHyM_?Wt3;anB6``!B z$b@d1%Lh+jX}0dvzdbJw;!>cH4Dt<kIQUI5e{pj^+sF6ZWHzd#H(UZ5>nf+_U0PwgXT zgUzMR*++Rn;kqAssf=_wm~I1XDfIc7P0QsS2LRx9una||9i>WTatK7FTHmtk}J&^Z}>f_i9I+L4q`nhj{~y z{gmCqe{-7k=KYzi)L^-QwsQ?XKVE+1T6@DV$J~Gb1+(-f4%s4Us*Mpbg zU#w8a_S3oX%{Sk~vLTgO0%(pNHL86PPiZp2MHS?^bZHLn$EubhNKh?0Z)q%2KVMyO z**8K2CrDBS8D!uij#jwiRd^ODbERXCBdRAn{o_N}+%<+|2EC-89 z3Nm#5d_K_~v^)>0R*(z+EP3K0gCD8SDWtbP-m{PM3#t7^?fuMtGJ9jTN^2^4A- z@;0PqEBN)!6CrUA%yT)loC6-M)at0pxbS*>NV>e=Hl@ z9TL<5PPPXb!6J~?3{RrYIZR}*@Sfb7EV^Hr6cgjDdS_l%o=WHF%CNO$T4g-|gYZ8g zDLmkz?uV~R4&yEg3AS2Fy96g|&|GOPL*|r)jJAr>KfdeXGr(t16;0DsOk~hk_WO_L zUzWu&Y(TE5Z%c}TQ2JUEod%ucC4J6!ZyV_Z77UGACI^vJA1Jm(0E+%{G-N#|VE)nX zB-0R}XTR<>Av+kM!E%)P6|UR2)Ne7}8i}D|=i(IWbuG071H6ZWFK|#lyU=;R{<=qO zhwZYvD+$Zshw(hgc!XxUg!@3~Tr@o%Xx1F>hRnFy-7hS*J zCt#E^8cvFm3Wg1>pW}&q`38|^!@`aocRKZHIRkM59~sArh(mHrQM#32Dt!^+BN^}D zTK|%DviB*+bNwMlFVJGePuw#C;&lr@3jVYQd}Tg$3u7N$O`Yc%5JaXSYxCncl3Fs# zlmsWGAS`U2)cQ>bm3AY3@si{Y;iU&FTZc0xA*n{*E&UP#lkNN1XemPXznWD9T2AR&5Rf^W7{L;Q+Nu zXkMjj7IC?3R7yWdH2r0)JWcqRIsr&T2gD|=Gq9&MNfK|f$V2YH|y8vm$CgIiXT(0I81Sd0#AZB3FQ=Je&Dlrew>O2c7> zhBgE&%F>DE=GU4(S$|@iNpdTNVFK(8ay7fyVU?twJ0>M!88Ec8Bz5l5Dq+E%{eT4Sevn+9kF4$0j;+aL>WO>J-)?d#LGM<|2!uloV8|9 zaM;M%o`VZ-06DO7z+o)F2C6nqd<=i=+z}KT{_Y{Ap83KV# znAwNe>wNwYg5^{snL^(w<*Bd)snuHJVnVt`IU)H~_9ab+)!_%{{m)BhFPQg7EF}ey z9Ey254NY3@*E31-lyo)QyxV@c^I5LRJwrpOkmJEcl)<3EW^bmmpYCyv?Cp2ckI&n{ zK0X41;V250C#9wP(UgS`S!u(BFExT9Qy(sfSh1I}&?5x64(*PLS37}9wjrOpt-Nsm zJbQ3B!?qn|nz#;=NA9=D})wgw~3KsMbDvAAA~V z5R^LER3?=L4<|56m><8(cn`}!`|NUNwWLc#6}{JtmlPe@u(gXlor2%}p~@v6_F;@G zO`nnG82drgYa+c+MNB{HdzX5w08=IONf$5vB3c0XRB9|&y0*cXjX)N>&+mLn5fx6q zjAO#*Tvq!DC~_*JGT=|`_>D+*s+kUnDNiiD=0;LUEmSjg^rZImep@t=Wn~ygOgtL2 zabw}22fq8LgQJYcMs=ka@c&dbR*u$`1x8N~8s#6>dRl<(XsR5ksiYqYGmOqM+eXQ* z;oywW3-Zt4BzWjy+B%IQW`Y7YZX9%yB76R_fhDA)0^EKGJRv@<>2~n89_LYsHO2wB zn(R+&8YMj{fj-UD#a90_r$d;L_>L%gD8e_*uS={-yv8 zBGC9aDTb~d`mIG}m3uo5ew_3X4U(yoVQE z0T2(b6Cf^GU`6{chNyRuqIj;1*xdmjm%v8f5lf(O1#e?qBf@(FCv8%{uxI5hR}iMN zle0@Ib?|-?xJTeo`k#3Am>^$7u7QV*Ee`(=Efx#OHn8-Y`%b)dtGC;ell|gmxgAdG z;$YO7L)4II{X*ay1iJLm_lcu!bv7dWV2hyU)vSqoLR`33C*!$t2K{r9j z6X#rMCU}G}(Qr4SilkVuAc7VDG$+X#jbbeN0(0-l6O+)%P>5?PSuzGX<`%g$O+nvW zr3C8mk;bDp2Vj)UY%2K^-fRMcR#Qn>*FW@XcvN0Zm7Q1iNBqq7Od(#3H^jx~ib~Q{ zJhWGuuJN!B5jGTb^g50$OUV5MBX|p)+l#9=5-X-HIC1o-*vN*GGiIgu+JUO}vzyYi z_{q>S-L71|+{5ayU)Y(ULxxDBAh8yS>kTS40hkD+=?rE#_*SjalwopFf5Ckm5}JAX zB6HJu4}QFW$dEU%l~D}%z=QB%DA@3KbFbwT_s2v(pzZn8%Gy338XwV@JgrhmiS#wF!5%{CMJfpQ^4p zWq(^-2Iz2rS%8j3h-4ozqBrS*S6Lu^b9sy4YJT8jAd@&6Lun{CXfpf=%hesAE?oNFs zHO{>&J@2yd!q>?@f07n5mZdPh^U zNdO%45!?uoJ@ArDlEdO9I$fW!dw8RgivMHld)Mo}X24j?JmPr(+O-hp>64Dw!TCb2 z#u2}&2BR#K5A z!EuVP;dOOS%E|;oN%qJxi1Zj*b(-R;wmkj_hkuU`AL=uej(%!cZggIMtqacu|0EOW zu-28qj2`bi5x#3z(2+@~Z7g_OA&I|IvP-sR% zX&u@e=_r)BRMZ9M{r$nfhO^*8v9zbe+s(G#-^;ex^Awm94bXrb(VU@603VddNdckU z_<86NiZxJlub0X8rFTy3@BzD#$0$Xm#3jI=97XKL zMvp$m!+7<+5R5m8#<`|#E<6VXJt)#c|N4)SGD;89>WY=2NOJ*voPsha`_p2IB8F6+ zpuf>xNW~tqe3{vX26bjYwX9Wg=B21g&}fy@vbK-d`)HZwe=2`){(PDdUGRkb#fRU8 z5`$y@11(6iNMjds{^Dq$@oHQd{U=F&MTtzRx>ek7lqXjL0A zGpyIp!6j98WT9xvr?6dl(}??MCW~&3zy2;1Nv6x)X0Hj&gZ*kxz_$u8z@B#d) zIo;hqLPaESo|({vBYDVxVgE#aH0O+-z~4cT+hUognqs~;r$*vg|!`#6Hs zZ$q}3q$P5!uoG;jScoQbm8i1lhdW+HtwQFz?sQ#*8dQq+Pp1ppD(YEk*2Oj_#x45o zn*bzzF}&g(t!q%i#-nw9+O~!@jAkq((t-Jx*nul-+UnaM#Ia2#3(zSHY)}Y!l%#$S z)ui-3G4W~IQl}vYmwX+8|CSmbS$=6?+onf|bBh-6f@VsodozjdHKnR}0YV9h^u(Wm zCKt{TLQ<1sTRrBZ!Ps)JYSdB^IdcWoBwjs(zSsU}y8?6{N1T&sHBwOw8|mTUvDewp z*M!(s6jIZ>;b|S|=$_oO8U$IsjvH5T*f|tTG_?qt@8Pt41J~6tP)m{Yvr-Bn#$`fw z<YjKMRmAwrhbt5<6(A^vVq8jYJHArZ;0_7!)IejQTt zk@=0qE;x(oOla>lQY&%cYigyCASmnElyTgt%xr{Xi3AH+K=fE_coQ1UigRWq)8F^6 z_XA@#r|a6RDH3=1Bn#{N=BBMX{WkgS<1V?V=-16bl5|LPNDzyGfq^#UbYaxkKB5*< z+J|(*=u)KF#PfJGO!H|v&Z6KDpa*zfv@R6l{Z=(@=#Lc|mvOz{FX)+p!936LoU?BdJ=%2c+y{#*$>&&#{3e~&UK3Pad#Gac zm^BfKRJvnuOeC)yyJ$s&cazCFu=wre*l!zGSNli1IGvh^? zDf)NPE4AI|!A?Ps4gYDju+K2V69j#86_G2~q}31|4%uTBY||8?nus)&9wg61ueiW7 zdiIg(@r}%FUgoXJ{&B>gbKftT<{hKZ)cJfo^hhs*cArbleKe|k?4w=Y$&ck8E;t_1 zRS30KrJr{8m=u@&{tp3eXd4EZ=)YutNk*EnI5`^RxLh*PnLI^qCe^0q(CO=BKs#_l z!}dcMlsND0>Jn+{Fonrr2IfaGyhZ@f<hl?K^Rjl?A(JvXi7G@Gapw>MH^)k}MK?|#WCGb2Oq-s=Vy)GHj+#05Bm;}#>SH~OMup&3i#ZJ9ln0N79JKnoTAWKCIpkwSMZ!WTdQiS=)kYa0eM6 zWD@8&OHIzqCbnp<{GTQ(-jgHJ`8$Y>0RFuK-M>*Jb;(@y=RR6TOk83KbG;Y@q9-hn zez=ohUh*&?uEI>!J^An^GDjnWYFNRgM+_BpF^ZJn>haAL0Z(`QPgvCrry+bck+_f@ z-f~t)1`A}Ij80IJJiPN>R9M=kN5E5DsYKV+a=LtsEO>JA4h&kcYTHdl?s&Bl$ck0o zVSN;gzDzYELf*|ay8yh2UM&nOh7{opkq>0;ZK?43Jx~918<5eGogh={0dUwZmjm zydaA5TY$9!sz`iLFISIm$CHq?CCAXCUqqd8?^b<-GT9C0Lc5C0YNRu;LyIONIA+{U zq}a#mDJD7$^L^j^sfYxaub4)>v)S6)E~pOr51njCU;Cu_&mlJ>j@>dA4F$0qzIKy{ zDf(Lf!<|mxKW0IamINyTY$I>m8f)XgWts4Rp)7YXFZbWZF#1Lc%H+WmJtHp;z;y4d3hC($14T6+940PmvWZ22TI@#Nk zh%t`XX}y2k#+xg4(OH0?3+t{ayaBoM90w}TPg;bbi1wRt;AjHJXs%hiRuVZ(Nr)yy zW1lN4h5VJynHzR|Zi-$DS7tRkoJ)Oif)l{}=s{Ic1lW#_=R7O4m1v942yDQMr8J(E zTE!ERT3i*yB*vQhjZFdu?cymk))0@M>WklY6)svf09Gy40NXulPgDc)7m)k}(?wD! z%bE&pLu3goz`l956S9CXWu4I7YrN^gb0U?$N01dn@2b`rV3hB?^>mCnjO0x0}RG+q8_vR&?mp*Uq2T80*kFS<@_u zo}6~fN`x%z@RawDnuA1`jQ#-hUwmDx4hjxfk0e5EWBZ<5bT3ZCIb5+w=4GanouSJa zSs(Kif8z};;Q7egR2+S%|B!`5Rot^ZwlvFeQU7WA) zD2U=!d6u|I3-jViqHgt4@+DIdoJ8!{c5$}(vJ{M_N0oqrsm%tk zR|#2Y*GRE=|9}|}zl(~2RF8#|JiF{k;IT>|z*?pT68UN*H8Dk%Yk9Tm-dIJYeg@m5U_y9rQYSl`?i%?A;LW>q609p)7 zL{LuDYNv|SG`7}8s@@@{HNDywq3?X%9>M`GP zrS<+G{(T%$X0jbEYJ5o-uUw0;NfoS)ugSP{>Cu@gK@k5&x?mMM4E2^ihO`vcg1=KYyT!1Y5+Z0O zfePYlE(5zl0z2NN>IyoN&2+5c14#E1jNS_`L`js-PbX5EvS8BN%cx?=mNhgILIr%T zt+<2PBH85vP%yMytM?UM}R=)iJpypJ0H33;QId_w(+o{{9{%}fI&I|B`P{{ zY&d+c4n*;&s{iXK&bVgQ8_^8s)08GM{|@dfD_r6Z6K5gi555{Qiwbny$yuVQ<_tkK zZ>I7$>0P&u*%mfVmoL%TMjZ*r8zNg&+of3us61I#Po2>xg zd+V>$9Zd*%OOcX~I>GT{ntONKL(`*00{SW~vJ#8K{-13ab?8Hjp^FczYTQ(v!%Y93 zb~s->U&~Qy$x@lTL<8p*l{N~+_8IW0R>9Ad#e|V( zz5n?&RbU=x>^b``zKSx9t=#{%orSl@C@bn7Q7tGp4m02&IZrS{H=?&Vl+mvyvM}qQ z4FHV%W1X8KX{P=TdV$W+g<;TjpZ#xw0SSVH=UmI{OA#rS+z#dUe-FfwmNgui5sea# z158YWVh*~vAID$r{moSwz#qAB6#Bo4DHqQLZpj=n*8(T0;7=mo;=mh9U&47w&#U;-ZQGwuW@mCXRkF%%^k z*RNmKTm{Q03QbliB4jP4SRM!`2_7l)t@G+riZ`?FwQp~m2tyAnoX$Xm6jh@}{nPzv zP;l_T6kK|8ByNkyDtBKW17-~uck=)E;#2|Nd$PzuZ^6pq^RF(vd%7y@o#puPIue0s z1GiVp2v76irdLMlEKPyk$=+4eNu8Q&3<N7N!vfFPT6PbrYHgvM+YwwEs|y$3<4<4R2s0BSAtOwOj|IZeG7rI320hFYk%^h zKK0;ag3)8g@^kjJOs?8{MVE;35w^EvZQ2728o26yIcuLqI#Uoy=PeikN{R~Vxj}6^ zlHH>v>Qm6jsG^+vEOjUHQvJ~WisFJfTl-9*n(B+>J&~hj!8nY%=q?yfewjkrMflj4 z)JP9K9_-4Rgo0(k;}b77rHpXlj>oK3a%PqEJTlOTM5;~Ow)+|VHR^o)Yelz}Ld6q7 z9zg#xIz}JN^Esqh@-#Ed(RUW+GV%c=F1a*EQ_g70S$i?W4d0*+OnV-?^$fCy%C?*U zGd^bs4t3THW&!?hI5!=J&qnpQygUvRQHNZbx+RjPE4%K(U`27wuUl7_kb&}H8+pmr zMEfZ;05w^rl}JvN%2nNa8A-Tj#(mM=&ulU#O4MW7fWBHvgemguDPn1n7PX|!TLJ`s zjoK>8?BS;*;P#yd@8Yetj>{DeImG>w0FCqr2mI?li}dR4IXvHK1Gpe4i`yAy#wvtA zJ=WDC-su{$T!<0Rim8r8G@E#nh+Rh%&Ex2JJ|Lr-c}Q$gRM|N2>MX__xF^kb$kW=JLu-x58uk(>9U5zD$2 z3wMJ^SEABfXPEe*ieOOt!xTUClsH_AfRmqG9=&GjeKy8lg z@C{tHE+}Z6wJ7jwXcwMINPjE129MIJdRX?GFTw*=m!WPbRyCE}fJ)PGhW80?!(T2L z^KP8Cu68UsxzNbx?E+afJho88n@FglDHP%mcIY~!BmrV?nYZ@ap1US&oA?QP(L&l2 z^o)l_`NYs;@-Ne+dty}5?fH)kxdso_c{p3tx;y`AX+xq5Y4hWo>E~A^tC0}FZRG`v^{Hb+{uD)`s%c{2RG6uZRK$A#fFREW_bQPeLGmyoyy|9vJlci)fSYk7UD zJM_m=#f6IOgBcqo6mi}*%qbE~Ypf`55Lpr-pSdOr_HH7Y71=vd)PrDJCBEp?3c2DK z%FRv3j~@qr=uu{U(b{Hj|8}iTpE;ugAW2=H9%*)7jwm4PirygYQ_E82_IE4and^2b zmFkjriZaYq0yuDri{WQ3*24RQ(c(+#S1Fp~N3Clqaut!l(s88EHfVn;((5k$ySO(z z!(MM98wc@=HsqTqV{&f(1L`zHQ}hIqxR9?zjdkwcmV%oFKUGvV1k28t)iLcBswo2_ zp67LRCZh-xq}=P|aMXfOv*dWJ7Y}gj(l?8G zi8j=60a*%?>%P-diTZ5wOScBLZy`S+he-4~2@9$HO4d(0y6v)%`&ABUUqNJ*kw(2dC`;lEKfzh>*%vO$F4r{Hm6BcxU^62@!MRc>5$O!gS(uP;g(`0%4YFU9eJa9>le z4m-?l5-~^BSt&9ED1x%k9ICra+c)l`5)Q%89`B4pF~PIWue#^-oSWCK57%7qV)Luo zwE90Ij6RbYbaizt$fK$*3uCI@zbjdan4_{Okt^Hr(IxfauI1CI_@ z8Rk|_KyPXAvF!)y$pVBc)o{gnq!#j?{oe5Tpz2J|Ex5^~kC(MF`TQxmK21kjlKHux zp%2*JY0}dqG?&VM`=XJfL{a#+guOmJ%um5QofQ z@tA*Z#bR02Q%bMI%h?#b(h%Vfi2pLtEu0VPqdo+9s(z&n@MjZs|M_>J09gzj@V1tQb^!J!2D70KsM4{p*Fy% z;(>d-DKMqs5Upxz74#Hs-4s#3U+vvUVX#pox+t@FTM=R&$ngQlu64N`z)c(e0B@VG z#VX1AnBgG?7UXQ4NgRMSE^GC;OI-w2k7Gw`E?N;0cbP zQbf~u>bcm2DdTz5av}wf^0k+XLy-%^gyH^Gy z0MScxo`AW2m~&QL($cb=tzFBKujt+H^y373pH)9?`r&CS=p`AkKca}k@R6)mIsAckE)s)zt~ID$$j?E{MvFlqH{`p zNoA$JzkhwZEWKZlibbsKs??o)U=u*_`&oE!kNdA2@~BHXYfL_pG?z8jOCDaGaemL<*LUI2H40e%a{W zJ2}~C?bY=gIH>LhFS%8lZ==Vq;rd^T6B?zcwj=+8>_*HiLe2bFIr{UqH#knPFwV|m z^hncQMf>wE|NS5B<@40Pil@}vY2B&Sf0M$c3Dr{Vq9zT5=L+Fa5Hxa6nH`WVo4}dvMa~OQESiFe)l>7)Jib{t5B}&S$Dymky zly+s&7I`ljDSb8h+|+D~N#sP_e z;8k#5VgV&Fbk3G%wGvSXX2N5DgA3*-BR(WhYr5r@S5}nej;3dk<^*a>2bB?hulNPT z(FDZ}HI;Xr)Jzmgf_+IF0_H07kMtT=L=GDnqhl~wPF4p<%mZ(aX8>kCS-2U;*Mx_M z%ZvHGld-{-cj0Jogl*wxW(Zm(4p+z!5KU_N1XF3P^!6U+IVAdxEZPedj}vu_Yd?PE za7we$g|~`|D&U}VIp`wv;Ct@Jg|l=p6d9gSDF>s_JQ|N)+3@$otx4cAD!ppB77oR()b7vcQra+jq~$A9(Ztx?Ab?=$51_Znx@BfP!?ycGmuv)~YA= zi>;jAu69JsB~l6?Pdz%%)X%(o`Fx4NG`F&BcKpp&$tiAN6cZ>`k$L}x(I$nF3A|hy z3Ox9K-AVg{z=))lNU5$9^AL75BT%j=erq4C*SGiT3x&ggeu~FyE1v%=3fSFPO7)C)PY z_$Y82%&P;LaM0Tg9Nr>wcXA--e|4wpQMc(U8itILLNTNkjf4zfN~&cZF=mB@T+^nP znJJebqtIA$VcR$Csvp1*=YAx4!0&KejhK5QTj@B%7Q#;E41jQ^ZCxOv0o{ zl;b@0XRU$*O#=5OJ-RNy)Y%Tf5W8Z)TNz43IW*w!>-mXPCzniJ#Oy#kWwEry!iGqH zkN*Yp!8p(S7t0Esd-S|=eJ%g{GrtUbSAMP2Zq}`zlP4)uP8YMc+D>h%w^L2_j`7ILx`b7h(id8~ zTLDuR7Al@?o{3toXx??6zA;sZEgGGh6sKB35#!y`KX@$0g!D_cRMVMX81UFdpn&3@ zP*`oHW~6i1-EJPNa)MvxBYK)lMdh%ix1C|1T=nyFnkbzay}wj;)TsMsbfLQs3Vmp)r`Bwm>p`PG4T6z{<_q0bS5`$fxfk9+SD4+ zr?6veghy*C9}?Y(`WC9HVGM0y>H{*yYE^TBR*Zym7rDgFClwWLyY>;N#Rp8R_Navz zX%t7(I~>R>Q$PfdYfvz2+~p1fZ0)#)jh#&PriJtrHc;?;+*7yrN}X-+j8?{FvpQRK z`^p$*TRKZ(SQ%mICbum>X^9usILJW`&;*>Hd~qp-IT-m?r30N8v17xFKUSCIdHMM0 z(GM~;D54k2jmZ#oB9@UVn7-29SZ%}SRS=e3e3~6Dm>OXD6e^xAJtuUTo3-oxNf4OQ z@oM!yK3JVhY;yO#u)s4uPMG%AG8n|Z*+a9CobM+Z&dSQjgVPXLhiuu<6a{=7o+l}j z)k2|HXAKER?D|4FN=cg3}-Y3W)12WZbYm*Ht2`#CE z%t?~SuluQea4zU+0$PJ7teTez&KfdgE5rgJW~r0myYCsXiTwDjru>KE>md{vs_fSfY z4ic>Iw#eB)A2+XAGuk^w9n!oS$%|>yBn}XV zTCsGJW3#rs50D#o{#^k);#eQ%X{yJoAKWXI` zVtkzQo;Fcz()jP!P?^Xt!=G3pGEf9tSO}6E+TVHn-%AvM@~FX1YNHQ3cOVs&j2`Wes={!vSW@`pno~`d=D=V*? zWB-rMM1}ATkBy~;=;So%n)0Op`%d0_JpHT)d?~~J{X)r~{6r$Y`Y`q)>q?%aUaLt8 zlCzeeUKE;zI19%TP@eJaWF<_@3j}1FOH1^3lIcd05y?S5o20rS0={a=JGpyu;ZBI= zSXN6s?~GCIyW8d;`1OnIOye&qd0fH{ z3H(EA21sL38LED3Y|GN!`dSq>c5&@GjT|u|Vld~4?Z*dfKG9qprvZkY4L=|3 zuWnHqe%5J7to@Q|>}FqTbRy4rb2d`T8aM`n#J{61yniT=ADx&aL_A#j+ZBT5NI;vH zrOwc=_Sz&D8_)UYUr*aMm=rJW0yEFg2&Ke;?B?_5{Syv1cpxh@W^i#B-%avS*TWaw ziUAWEO0WP`kSQ>KC zAd~I#bk%=-c#|p=7#Wu@kEGg~umN4bdSLefT-|_7b*A0a(#2-U%(EV`LZJcNj^wse zSJB0FGFdhs7D^sLZuc(&x`$8?hzFh9%Jbwfie3%I%bR|aaslo`j=jSEg6d=&jMfu; zt^mq)D7OF=Y^dl&k_wG@`EHr|aohil$R$J_5gYPK6IM7sria(N;O^i}+KwO&Vkslz zRvS8-PUU=xn7>nCH+;5YWyU;@7!``Pb3mp#@Td!fKf#h{uVq3=H0V4eF-QHoa>;7e zuJ3|e0l8@)Ge?Qye=-7Fc&L5g3^7x@Fhq|0_)?^H3x|Ty zur)#qN9551A+nxHK$S9H+BP_t2^9m>_oc(fh?En{WRVCCGnz;6CG#6KG$a8YZSl|a z4QFyl*1i}f)M_=P2$!%-KRIfug0G#*jf#*mts_)+eNP_RE7-riu?tK&L}J9Ml>!J1A*ha> zHwmDDox6ea0sdEe=N^`0-uC-uV`dCvp0O{b2M3cgdP3}TT$u^@Hn`Bo>w#^uf zUAT!!gtSK?B^5Cyk!Y07kT6OqinKwg*5|wm^UnLO=UDGr>yP!vavYE6Fm>Pebzj%- zcmB@v`~9BZQ>49FWko<^mP-HaHNXebKszb$o2HdbodR0T$GZZIrhwm+j-B2!ul}o6sWAt~FXP`B zLttfb`MwfSpMP$2dFle9cHi0bLr;cBUH?#qQoz_7s#L#Ee=^1mXkgFE+GX;10rViI`t`~0(kA zmh!YGnYJ~Fr*P?e5Z*f@sYTF8_i5{-u+}sfuvxB@h>fpyUln7%@`r;QK*00J6waueDY&R?11aDx>4vMq zNk9x-$&h$e7@VD0UsAaIU|VNz6?BGY6V8b(xjvmRK_52wybED1e);R;I0PI99kUZf z1oA~WLk0-Prge6#Sb#y$skdC6-<=Lga>O9$V1o_OG@8QX^+{{zBPzNL&ASGQ91$D3 zxTQwZf*%5~ASNy_Jj3y{{nySp+ndtou=*Uf%USt(w5{))z6yG=bE(>TX3rUWteroy zP90G@ZfkGgbgO60y;UV{kHlUtD82ZTj7h^QZ);i}`dqpqjvnnLi=PZoHr^&@lKX1{ zOrVZ(q4jikuW3KsRzN_l#MaSo8cd+b&3Qt260=Q%721BF0!Tn2q4(3rFJMYzCY5Ik zJ-vX(UwC6-H?nYY%E!_I#x53*iP#3R=bsP9xM!QF*Qwi8I^D>nq2~{A4Wg5;gAdw& zba{Gl8+Uq6ps$Qyy4nuImQI`(&@r;fb*jibx~Y+RP6>38o=(K+2kHYIs=rAGI%)60M5gi-FJ}B(=SXSKuHlq07Tt#Z+iJOUV-(aP^)nAi`LJ%R&Y~LbhRyLby~?1dC1RGnMIii6B3~92FKds z>xq3g_pVEmf4TL8wKWKmP(2#yGjSH@YgITExt#QL9*Tv6xYA$|Z-Za5cwI?jF`(60 ztgXtZi2c*RZbwUUC~^-D;u}aP)EP*lgq7rEGBDk;kOYx_@|`5)!PwTxUPMcf0qgsS z206ft-LD5ISEQWr_ZYCF(v*Y#}POD2AMsr8sJCz;ECNY zeeEsME&-{5xw%^3dRtuSKte4UUaO>D-VU3H{DRZ~jvazCh`?NwXL1mx(FNiQ*(ft9 zH`^h75R`NBlj7ol(G!$vA)pLqq*?(_YVVpzu@zwUdQI_Lp!#+R(KZ%Y3&wVGL%eV8 zUAJysT_Ge|X*9X#O|0AW4&C=IRW@+#(hf;J8VXCIx%+^+8?smP&bCXRv>O&Z`0oyAWSlaUHZ4A8CQ*Niufr8hms7DhQ3Yp8c%1Di+Pr? zEz+$3FTCCfW7kMVD!X4)3W%+f0S4BR|B}k5Z=ybpfb(#Ao1by$S;)8_9QDesc*=U> z{CW>{EIVh!u|(%AAp`|QNuu(WXZ<_y0Lg{!^>GGsX^%v;Z>o|jKkrBd>rO{57X!%o zlPU%6YZU0N6bx^(IrJF(aZL9MB#;$QX;?=RcY^*oDUp%BB&)sjSgjcFn=zU zl&$V$-B1- zsDNL*7>Hm8_XkZ;JMMK=0#rb!QJ?!X#6aKrM+JlXK9Dq^W-!z*5XutY1h9Of;Fb=7`TG}sY zRad9M!-$6F6wbUBb<2MGi0z->C-h*n_km~(M6`!sd9$-jk0w8+DQRoVdO+th)JC*} z4eDb+2Ou8qv$re6#Dg_JNUdoypw(ZRucSRrWWX1Q%MoPFbHD)Jtr5Js2KASm2!!Q} z_S1uH#$Nw`IuF+6Mg_A=BCMZ>-4kA=vqQ4BVI|HPAs^MF}y1H=+5jl)U%b~Y?&PdVmA)2a3A4bgl zpQV69A!fb8m;DzU-tPQ=E9C9``h!XGK$bM)CM`j3Nda!&Cr^;89bHHHY*)f$(7_ak zmQ;uU($<{OR=sb2-z$vML$qLyg#j=Ks6xhaatWg-F{x`nC+cPS@9aq^LB*WemYXiB zAd)&98J`C$U3%UQzHB;XUgJ_(uwZbZ7x=6PzlC_ z5>?77{86dRNASO1Z|m4`Bw8&H@JB{QN)UVG9IlT~y3CLRTn>C+ z6Wp9FOJ(rd?<-pPh){(%METs4)X|zv9e*(Ma<@Z2`;Jm* z`-9b2Ux0ScxH>l=V?R$>BM>VBHU}9-BwRXpm&mLM!D2(3@is!0a)K3dsp9w~$&pM! zlMF=6e{mU>B%Qo*uTMSOpf6!c6m%@w&_`b@dsh3p957hVI}d|(`q;sv?wsmd#(hLo`x}x*p#?R_;!z?f06_Ff+(q_;*|Q-F7`^+&tv)Y)xbxQ{rqmQ#qS1hT$)i^(6ts=no66 zXw!J@^|$1gw4@0fC9~{UXEI%oN=$Rf&qununGqQ}<64L2yTJ%vH(pWj=p&X^b389mSc-3jFZPJlbyW+Rdoq!*^_@A>+Iv19f zmiD$Lh<$wEqYv3Z0Tv#!HZ@AV>2Uq8;+K%PF$@iiGQpTB&$OORP7&{)V@ z|J?B`7UA<~uYC~gR+neXL#Ye`jFG`3>6^Sj8@Qowb2IsDt)SkVc^j#OxwGBS^;|09 z*Rh8gCKeg>*t+tIO3$`#-&tEJG=k(N(?D1WkLbayKH;k?5rI=w>5C3qoMiv;+Ldda zsotcGVzZw4M8&>iJh*|VhZ$%^9~$6?_tw_IJM3nDo&TS+999}L#%MK68b8_VC0Q^U{vX4Jm@gEdh_ zJT}80lq5*``PsYV_Dto$(Ej_mDJdyfy=wkdH(IRj^-E&dV++Z(*H^M9134RH>c|+5 z_Z{HtG>uLGVG|*xbUug{zb-$odvRL#8;yLRRQoH~4vCcmzy*CYTf3};gb`03I1S%B zbxDr$l3vjx^Ujf8~r2m4~#+BF&TWOX8r%WlK-5 z0H$Oi-8i)V@hza#lsppVi2P72mh0Grh4VToZ!Uj5-(a=8u2M16k7oA$y28uk3l^2A8)R36x0Rc z_6ta6T%i<=5QRbvsxvSRc%AW5OrHJ|guQCca@45eQXshv__L7qG}wu^fhq^>CEW&! zV+P2z=vsyghdWXgL&wCpV|y7rDRKNl zEE-?w21`P<9Q;fMl?d4?3bW|w(mu}ou?IpDhjV$Iu#BgvcBEKEerS%6uJbixF+Y+4 z4;VBMzy|ZdH^mbaw4@8E(3m%9z=TMw1|X0bH^MEjUI10RTkL)JoR)k|YWtY-14F9q zi*dqIk7TE!|J}Un%r4Flxhje#k@AY4P z{<$>9(HAe;$rWgccQaN>AY;jA2oI_nQeKgE!3)c;aMF;Hj>;gMiJKi3w%rU2@`Cr( zjboGR9AQnd98a4gXz_)`b*$s8yKBqHdV~lDl6;Zz5-H*tAoy2zPOogePZ|}l6>2a9JxtIDCt+lR4d-SK*K1Y`Nejd7~>?I4%cl(YHZ+K zmmlwx4ZuFqsX{J>idND{)ok$}D3O45etNQ!*(jeD6_}QH=10zV9_w^5A(8|h8K9g3 z01x48smeV6`J!oZ3w7GD;9ej|x+X{Zmg&f&?vgogJYTU-q9eZFz?AJJThgl%m#JFR z_Jv*;alDpM>$1=Q`T_=6OuV$Ozi3O%?1xvL{AtJfclm36Dx+xX^Vd%A^4E$%brRyJ z529uKk=-%_&5HbcWAlnB=TRL0Yu7byV699Ps= zlBd(rF4md^R0qkd48Ks^4*8-nup9e9=ddLs`K4gO6aWM)#FADVAs;Nf@>JWyJ)fqHdgJ(a)C9!h=MAcYGwl*!De+(N}T zN%BP^oG3+DrGX-dICAn{207uCy&JY+tt&Z)NQm{!?@XVIRWzldI7GT;)DnIyHQ*MJ zK%LrEw-_FYBA4K0LvLu`ihnQ9zN5RP{(-+qswg&(h**@J^J@NAhMvvU!@B)z*r-nP zXWZNNyVW?0E7@P^H?SI6bUb^hp>duBGs-CZ32H9cUE?%~OiYq}e_Tq>&4UTsvi_u10?KCP~teEllz^=Oh&mkUK!5qem8 z5!JSX5}q%Tw$I+vf1xO_DPDodV1%EapC(8MmODGQZR?;eWYD3)j4h=JqWqdgyV>-3 z(&ynXEvf=Iv$Ep}$i^U({>NBOf`ZBWP%8hyNV_A@jI-GBqA_kWKXF=S5;l@cORw}d z28@-FBWzn=l{kUPiBGv~`GyOv|KlC;TTz1)7#NcqQShCQ?m7JMLx}Czv95Q-+QNT^ z&BjJPP_=ikfz)i@?z?kmNLiV?;AON{T3T9a2o%|`1U|-kqI}D$H$xDm#v>l?yX5)z z<}`?lW5Sf$PxY1z1`iVfAz4|+jmd!`VIn2LbE)gglJ^H#iz{3W*D;$tBt#+a^fY## zsfQ&dmNdq6dwIsz)^>+R^U%GPwI~%VwX#P<5_q-Vsy7SGG5#F=?BJ@}8i}i!)k*~_ z-_$$e_o8PD_Mkb3G2}osl?h)`&2rNuv61den$N~bZAvNNx$gZXolU-5^Rw6jsV?34`#z`AYUM^PLps4|$?HIF$f2U1L?&F# zPI!MM3pVp^7?KP!L~4gKf3 zx_irw;`idY?)@2ccFkSK1!)&A9nY5&M#zuJuVb%gLMd{{7-KJXwH3hwORB#!qkk z$=`hC_kX?!l-%4rd)l&j^CHE42H2r1V8<4-7r!J7dH?wjoaXg389r>7Z(LDQ&66ZC zHH4R%g_5th$@?4Zz0u3gU5m_909ME4CCMB9_QS#2_Djx<(MUGOBD8aV-~OVN8unWw zAD+Dk94W3RT|4mp>brE}2A5C>E&S7Q{P-{tdV%s^FQ&ba=bZ4LZrpd@{lsmbk}$2c zM#W}jLzXnLc=vY77hVW_-Xu$Gf5{O=s6*LO+~oZ;@@~T-YTd5gvtotxt$T4;d%YHl zlC|!i-Vq!B=E$#=S?erK zyrtK&WnR3uJl_2q@Q{}`2#W<;^!@eME_6!Ji+*il_vOHi;*%(bqpMzDo{Ggz;q@hw z!*P$RK9ud_ZTO46y|PMf+!Ta%@7^PjYguM~ty{Ma<2)^d)wCQn>PVcH@>x0$8Y+)s zaE$$%dC8QWcU-Ad_o-&H;YV*diEOIvpp}zpedAKbk|<(x@=P;cd1*x8+sg#!SIJzt z#aa0EsUTxr+RaS!@1LIc3?PHJo-x+i{83qJFqegY@)3)XFUslRxwN}y!m+QC?d(ST z7MOi|d++C8Mm+wP*1AaNdEv+BRa6|3e;8&NTcM~hGtVBYGhJ;j_hoR5ufKn{f{dmS zK8Q%hS4|!T?oxiC&3qX%D=dkK1;i?s6n8iMO!t=Le6j0-LVMabw62o^+iU=Jq-|nC z?LE-M&8>~Tlk?B@WOhhy?sO?&!xPpd0iCTYO12(2E<}H#o6oq&0mpxIQm%B$-Bp2s zL$q|A=7q;E-g~+JLK7cGhPKfUsEfOH&dJvOc9u2I;htMAue_!$B0?!-9{%2tBlvQe ziF_3s-}*3ZH!1nqJomDI3rJ{as7(K>s_o^DK++nZZofsFm7QSUqApJGtP&ez(BV=P zi!;uJkS*GPXxG7fJFA67Hg6*DY&WC6Y2b!3EUs>N@e|xYj>hKa z$N4Y%BLEwAtLA}x)_%?Y`sH;K*v;3^FFYtnNjpFdL=*(rC>_wD7k;IuPM{ZRSt+yR zTL~J2P<}R;750Du{#%qwHmncO-gBZzQ!8O-EsJ((G=Urs2^BCqaUB$LHB4{GIwdie ztI+zlVPuZ`IU`)STW(DL%bL{lX=Q$Xz6^n3{>E0CZpF-4WW@kF`o+$u8-eb(bJo$w zYM$+uZhi+><2GB~pQdq(&AOR-Z)nLzWmd+YG3(Z>x&7vJSl9?8tJlePuAg$7`Gt1x zp^!|%3&=SpqH5ZtPG6ZBu&Txj;BpjF0_k=`Uv>n&L68!?M&WM60(?8eL%-JFT2qDv&khKd? z4euCp(1M$g0#bMpcQT=jcKG>u6)6k3?)?_5)|t8>Ej``A(sBbLxM18##zKU~K5=6T zkPIi2z9iK5SNd_nFrNM~s)$O+@xw*h+S-~R32dA3R34PvZGuO{sC?P}G?S3fBnOB4O2XnPPwU?qv1x2|DGANJuo(iyYqMOTwI; zr21!k?CG(|Ed3;nlG^3d{u9|Ml-kO zUgVSk;vuC8;u_|l1Y%f$)3Bf4r&)FljpHAT*AN!QBZj?!>3c9eNE7$X?Ow3{cXo;T z_=*pE_v&?YWVO``GDnlSJy*Uswz4|4)Nb06H2J>oHO4iv|$l*=AK(f@Rh^id_`!$QcI0Uj9jYV7gcIZnYK-9eX{~L?@0iA z4t1_(M~Zly(j6^3%Rks&C`C!^S$WZLhs7^)-e>iZ!8%p8KJ2$bZlOIKzeoXSD=NxP zaHy@km1JjQv!B6?VW=-})}R9sS&sDUvav-|1^1dMZb|fX&LQ0Q%Q;E9N_r^%ozkK& zLwb;oidU{m^jt5VhaD?-e0QJhniwEyFK*x1(OX7qO!ZKU6qc zowN3*Zc~%!t9vFsmpU<~`afm~r_2<(JM3u}llU>ihp$!1td+%CQdX$}lFr98R)ESm zx(>NNh%+MRg^UVr*vaift5yJF)=Txy5hyBg&l9NTN(cNU*1jaL);o*@2ol1z;}D+QcK^w>V5YO zO+P++AnX3rvDZ3`C#$yt+-lKwtyNsd{-~(Nff$@O>d|fhA4P%@S2#Y`Yw&CPkkXH5 zSOly~-C!0~*fC|u412#tW|vNT={2zEk9WGS`)ofw3n10QD^gSO7ULaXlpQup`fcZx z@2sAE9LligDH_Qz>=vj7CzsB>a`&NO+Sh0I-RU>S_F4zKMrqw{__OF>7#yxmi~G6T zqCPV(#3k<2&h@&{#?9b)=ND$TzIXhfW~fiTY1-E}{0EdKrfyF@djNy`kia2>8+197 z<9C0MmV?v04<5`(>*oE4m3?+M`^t1n3N4YeDmj5# zF5%JGlug>5)B>9~K7TZp6_6#JpfZk~SyT5`E4#)LUXg&aoeX9^A2NwJcX}%d<|CS57jx z`%pX}ZEZEO^;$G{tgoA{2f62UgvM1lvc=pKokG{j~?i-_;x`d8=5u_rRuxr zbnMR))ESqZjy*<#(V|&%?>Ub|i!r`CP7(dghqI~K_dv0eTei|lL2?EdxO&mv=oaT% z-0JBQxiQj!4FqP8fevBhvu1;D3=a5IEB0ckcf=aAEoSNVT}%Qq49}RHdXvI_p>`h! zPC=l2I&iGnr5$C_01htqMfUlndD&L!oFS#*TDo-VloMfuwqrJy81avGPRd2c4GdiG z$mmi2*GCIhU(WZX7}@#fd^$pJuH2R0hzVIc0Q4I5OnmU%O|4e*C4?o6Inuq z!-$?mQWnoN_o55pY$%P91sPMqP}JJO9}*eMg;6p$7iaNBtH%>Get|zW$fQwmW z->iHShdEA1kS7zQaA8#Pu`XSYsAxSHiEv#-8T52pv!^$vwyWA*T*OP7Y{ zw8I7>j?tEx`!=p$Z^W@Yi^<<-U$2h)j!B+h>FCU^UJ5rXEpLt`Xxt~M>X8AYNkWNz^6=6y9SHo(yF=5rEU54W-rX_`cFRjO{i(E zMXpzuo_URy<$6_IEKPG(+uNy=J!jJ2oK@Bul!F)T9J@j=IvH%cv@krBo}Cc74+|-k zLSC+T{6~junh>F6G|(-g!QdI+DW62?B`remk$gdE|3sr9G}V;*knS#M>2D^uV>T+r z*W@~vBj=aoR7fTVpaaEp(819QNMKm-h2mMRPann}KHObQ#>su`S>M-b^ghi_(qH*0 znN@trg*NsaHV|B0I&}CJV-mM;Bdtd*n5sZxCGuok*H7M#MdBf5vp6y}8PtaD4i+^F z*_j+dK~`3y(Ww-Wx_Mr%T2>tYq=z~KFypW*=E-=FMqbiMiCR^C2#X!g z$s_O(kI{-G6~m0TjihLyEhQb zzwGj5WY%5&|3(8o`(o+E(#eFh!TlX_6CyYE|^kxsj1rM<+M<^W4 zAF_?%fxE!6h}8(^fpbw6vV;hzPqCcmYwg&+{h(;(Oid?(S00Tt7Ulz-sY~b17GhR_ z6m`e0U5Yn8Y42`Rj|h*nR*p-A5Vr?-D2J8D>S-!Ikk^}8n;j4$7Cy|LU|54!9pRk@PyG1O94I{;O1y;+eAmOg=7ghaUh_m&(iFwvTc z!H6Uml@C^b$webpdHr1_ z&rP~;5>!tXT@8p9u__^7rs|1@mkFOEek8I?Y;Y9!QCFjR4rjx`j4T=t6c8xA6| zqsvmB*gCzXBj6>1_0k{hIlb~!1_#*@T*Owq`fHu|43-Z_>OH3C8AmH?YX>lx%gx)^ z?A?~XE&qe!vmX9c#N+sBXWFN`{cIMs$FTA%tNp6vEpY=X!`XJnaHmX2Z$7B_>EnFL zl^zSH2Akg*NYmF44Q^{69mFT_LT$~nS{>S{`V06nqPHyfP@k@bhN;uKBgS6~*M05q zgg!YOLN_ZC6U`EgXybKaMBDw=C}d-fEGalZ*I-FRRjmYnBK4-W18&aC0{8wxeE|^j zu&d9`I}6K3crBbnzB2}iuj?RG^PWR@M5}6wkd^er413W>MHjK2tv$ADD+q zV=chxLSX*ZtKO8Xi>c-W8cv)my0G|!YV<;X0_-tzaQ_lI3HgddmLskO1qGZKtZIk0 zt-3VIYcwf2|28?S)tksY<}>$sw`UJg%baQ)+4uUs$hLkqL>bYGZxzkX;b|Bk%S%qOf&rbl)T zSn^#)aI4|%1}-1Zk#h5wuR1~V>E7Gci!N=9I&Ov(k>vGiJEY{eDX8^0&X4kKWeAnZ zA72|HF+@B%8MS9(d=Ku|x%4IbmWF&8FsNuT7C-Hx?&leYgkU1^X)7pWaa>LlaO~Uh zM_wiW#@CjxK+33FjD=lr@5VzoPZoP#Z>(!#V{Kl(1N!D5vNKZ;U5CnjwXfSf*_fz* ziTDD9d~Cf3H?knY0T?y&)auFqxt76eb>> z6Y5l!!cozRej~3(AB#PEcX8tcJjgmB2bsRUOn!)t73)q;VM#JYR>P6{SZs!PrXD>0 z0(M%=)h`^}9!<7+%9352C5rC+e9s#tG=v;txeJBKK@{`)DWnY^zYcRoCd;`DH*n{K zG(Z9qTxShjUhdh(M$dS~nO6h`eHpgV^VpW?{J9e+?iOJPU&H?l@ev2!agePwrTc0p zmoEdA*sk&qU8aJ38!eN#NF^hen(cDSs|+^VkQ7bZtQTV0f?SG>m9*$G!t$kA~Fx#W@BS;l}$*$Y$J#^(ntE5j!5N$@U5LyuJs zj=#5bbN9i5Wzzk9V>H;_s%_Dm6~%rtCtWWq{v;f`PP3A4ZvQ z89qjaFK|wyJ?uuEar*ghjpR{`swLfdf+6|!q?+=cGk*A?Py)B~lwj*eMVGvGbW!@X z zz2YO|54`le^d5X}Tpd5EBX`d}Gr!RbGcqdOgAaZ&JhXrRV2wLHc_*J8g87mD(Umo? zXC}BZT|(0%d&R>|&{SgZx7ahO7}>T)Qd_k$by||WeYI~wixW#n`=6mH|HS!YLukB$ zEhNJpO|>T$7~(0*i7f*QWogNr6%vgCp8>@S3>_U@_)kGr4a1|Yntf2b!aQ_jNi;q&1EgUFxT zbF;7=2(ogsne9x(vOwQ_D0`5`KK4||s{tD~ZF2WC85}b>V5TIg3A>kP8;~c=7_(h& zYKqh+=8f#s;;w-ou>l+;=VRQ8RaAa5SE3vyvXJ!WXl$$I&1IyqCWxSWtDuCU@f-i7 zXV}Gw*)Xga6>3lQP2-TL#Id>Jh6yL-lDX60-{s|?uV{l5?p#xonB$atahpI7s2eEI ziyi_E?8(*)j3l@YkXQu|Q~5Hfw&-X&xozIf*4~%J5LAqA7A-nnQ2GB)+TNGJp2ccl$jp7RNBBR zz?nE*@^dG{)6{R1ZXM}iXeiD2Nlu1ah^3m6hizyvK{8`n}DxthYEZO78Uu~i1 z3Nox_NH;NucJ|^d*PFXO?a4l+7|E0d3`92cQ|am+CY611Z`m_@!b>>-TmgbpThetI zXfkC>Rb)c3hBH&z7cBI=hN4OoQeXIOdaFbOf72*4GZPI@7@nJl=mL~bC#Z*vBtblW zGvSu$;-_JhezWPzlx88IM)%B7{R{9S1E<<0k6=aFHH#iAdZzjkh@_oLibc*fnVd&5 z(d~RMcWPZ3*#wR=HvHL_T9cEVj?LozMwes7bw;)xr0$))oqL`x6Q|=))CFM>((wf$ z{&)XXq%UriG$qpeLEp(TeFK9+9@;JRB%w$n=!U;v+#u_gCFs5Q^sxKvSd2P2R8I8fA`AmEhB+`pZtNfCnI22I3eWJP}1Ma9+l&)t# z`ciypD8ysDetW##dZ2`QF++1{+RVm&tCH2*WkzL9-b`^hfv|b8^y#*I-_;-3+0t^{ zOLe>YRaI3R@+Tucb}plDZ!b51%-lLXC^U4J>B@!oH7j*Dm(cap24zvGfFj)Nlv&_X z$8g_FFtKdPwA`O*1*18zZ_SHcNv+v&eaLz?;WVg?rt0LAeSUtk%bwoO{Oq&OdZCrs z%t)ooi)m?nY1Q6*+4?@~xD9EAF%BHA=$cxiof=MYwOOzthE^_IqWki$+m^4$0j@KH zdKRktYp6P$A&f6@3>~SCyqo^$eilGqo}QX!Wd zN|t_TW^&VRHQX_wB(Z6`c2nov@0P#)?=KSNp{;bNz1Nb&D#<-r>!Kw+cCJx%0xFyAoJf{2oD+WxNJuA4ZC+*Fula$& zedp&reRy9z+clX~fOMaMLXmfVXub9rIO z@gZUaC56{PP>sd4SC~tBdFymhs}~L2Lf}DFefft4EK%2?8Y+g&c+Uv%2$b zZnBl78=wAQyII`<^^1Y9?N;P^%@`8k(EHwBBKJGnRZ>Az5oM{}qU_we_&TCDK|3YW zPBb?Bf|hSfuK>>PNFPr{!?r-SBsYKDwP|(@*?YjeVUK&sFFuI4Sh~j|z|+G_d1v`g zyno-i&w2Z=dYzi`&;Na}zwX`S@Ba=(@c-Vcst5HOHf&qm=aYiv+SJ2F4~_ip`=9;? DJBCO- literal 0 HcmV?d00001 diff --git a/docs/docs/assets/images/llama3/lingua-70b-mxfp8-multinode.png b/docs/docs/assets/images/llama3/lingua-70b-mxfp8-multinode.png new file mode 100644 index 0000000000000000000000000000000000000000..e8214ccfeaf492c311caa8034ab1fe1b8348b745 GIT binary patch literal 232840 zcmeFZX*`u{|2Mo!Dnv;lNdr=u6Cn~Rq@*I5NkWvlOd%ojluBADk!Ubh2${*yqR7x- zo{Ef_#q&M(zMlKJum6kZ?fvH7pKI@HZ!K%B^Ei*=_xn!2Gf-1Qg_)6ykwT#`tEnD7 zO`)vfrcmhju3m|s>`T5eh5wUvI&#kGjO|6IYo-q7loO^-b~d(7Hdg1iT{U-bw6eV- zAtEj&A}zGd(#grrQC3v+^8b2;h^>Q#XtBOC4e!EWr>g5np|A&#zn7IYA9_q#Mxm%3 zR@8E#4*hUG$6m9vYIHjOK;mJ=qLPxLLqeK|6px6DW#r|h{s{eXgR^;~SkcO%L-gfm zHO1G_v9l)jP2VbXd@q!@iYDXzczVK(C(ivL{`!%1;6PrM1n>C}@pVe8xBLF@{|Hzf z9VAQlUq8YtA{00zrPu$@*KRn<9QHq7xz6v@P2T^0%hlTtn>Kc+{*Mn@eTwt{zUBY7 zCjWnHc{+gl@Zt9Lfznv>zs1wj)3PCn`Q^)( zk5f_@=;-LQQclyC-dbB3x_Q6D$L2>4R~B~9E;RW0`Hh85ywuGKlGR^wb-lx?r>7?- zA#wBOO}drLT)q(z>%=M&lLY=8s2Shy_Q$gLkI{RNz?-Yqo?n<5v;3YT{`&Q6#vZx3 zKR?;drW!C@yLRoIuCAiGI@7y%?{?|G72?`w=`C(nu}w^Dot2f9!!u^efvZCuy^fV( zTNs~CW>>p<;7VNWFQnXxP`F)RptEqTv&7G!IW#oXFEcZ<<6VLK^xZPEka(+xqUeCrh zAy@I~)2Eod=8UoXtd=`EI$B=3q@=CQ?(grPo|SboJbW$77Fplg+FH4Ws9U#gtr8g7 zzI{7=tf$72Q{h=}-}c$mef|3G&}YL!PkDhK@yfS0s=gx|Jd*i8TXlm^a`G;%G@~uo z=cmdu%qmBJeGyDdO6nSHPIH+V-QM&{|JK~(V1jM?d&a^DQK8KIvYm-ye@}$A?B2aw zRaLbJHzVD3;^J$Y*7eL>`wn?{EZDZbfAdCC*7=ujZ_T~#&rj7=)zqr4#R{H%c_`PR zSL|XfCso#&G0mtz>DV#G_g+h3EiLB^3f$NQby5%KUH!d|jqMD@U2%0}<<~dXj)w0& z|GHf;H}Ab4rpi|`_4`|VLP9A+d!{*)#zR^D91Q~l?uw92rA+Owt=K|BLnpppG*#2k zsCc21X16e7AHprOzPI{L>HF*RE4nXUxL{Hiz?y7SAiGgm&o`JubjQJiJSizDA3l62 zx+iM9Zp)TXdH1QJZ{IdcUTQ45EpU>N?@?v=*2=2e0{lyR~;^| ztHLhB9`SN9>7VLof*H&lb-2UupcHU$* zIx(S)%e>$=rm}kWfI;3><}I>LAKpAOvoGT;^kdoDSAjNFmoO&lW1WSD|cE4>{Vq&64Z=krp*W&CgHYw*{jUG>DjT7SIZ`gNN zs4YIce?Jg=mEAOq6v?HHcb?EFxTbOA<_t3xt>;>SY-gVyyR)PEqgO-7CW#3AWljB- zZEN=P^HXa2bL|&T@UhhNkdQThhFWj$*Sd!-DeE-gT~>DZrD5JiQRDZ~LQE_y6|s`m zVwW0|J%$;!CrY0yaFaZK{5XZO^`J))?co7dR@UM`$KGmovN~a5VXjFQHZ~!W)=fQJ zl}$}j&Arc`KPQ_%DQTyO@q22&l;6zU-1)(#S5z@}c6Jvxw61TE#>f+@0G7H51c&vESuIB?Q_26~A z^;2_`)*)PyK@yE)6``B$r-pQC@1CmPtN8Whh51(vY$=q45QFZPZ{KR7n>vbbcJ}nF z=+GId(9gbNid#TE#m{NRDSm#1-dh_6d3kv?gZS~Lwl*GYFW;~*mQ;h>WL*04JYMh3r6UhRmtJ4|7~8w4f3&kS zUo3um>C{MjoqJVVs^I}{8HdB;BO-?H9v{22hbr1>@iEp<@#acJWo6|sDw|Re@5z%V zX&zdC=cjMIdL`V|)g?W)W5&pfa%w-Uo9DCx0IN} z?CtII&(Ldh%qQlDhuMlWSk;z3FQFp@V+l<+Y^r zZ70@M$uoYtHUo~tj-53I=BB2m`O;YzKQ5?VlqH|wU?$Yj_|()ehu&&a)XUSi+-ferk85~#G69ul zm*Km8TH4z5#wHgpUhKkt9UU7}EDvIjPfJ@XDJdBf6H`=NysW!2yb>)ou}`;YVnc*& zY7NS0f^nhet}{t%8ozuwm3cu8r?jgqkc~2~V9|+_MXo(s#89u#q3G+^*erQ_y8AiJ z$H*H$ed@QK|JY;Qtcxvq*U}naGd|~D*Kaj7)SB=7KN@dpxUE3{k{K^0T2O}tt$p@uY6RBM#l@wlwss8{7gv-+ z@5cxGPd5jKgp}bl)xXdYZA>|{N=iyfY~McLpdiLSlar_7mFS%Y>X#7&ASulY&S>N5=Z9m$fNE^qS*!5+xO_$vTq@o8H zs$$wIZMO{|B>=zDF+V**HK$v-#yc-BulCz68xawaWSuk~(FHW+s3nd)g6J|I=eNBy zE%ncLnf>{68ilv5a@&u*YjutVKw+=1{NTd*tFN#320Q?!xHT|fNtqR-%M={2V6}3N z7@Ez<$mo34KQbcPQ1Bk-o$9q!&PA+UxsAZ|!z)7wsQ0U=stR1zvCzBEvG2GMwZ2|c zgZ#sYO1jzPuP?zIB8z9Vwarz+xFch;QGvwx`T2{TkL+N5Amf;1GKY_F=r{g)u$C%w zna^tRO!Chx@RXJjgPZ# z-Yl5&3Qc3QWO8y+X7K2dBN=u{GC@H>4Fmk+Du7PXix+~DX}GIq z3SNsnPqj6B57Fd$C-428yVsvBsOs9Q{t|DK?{aJH*1mGKuWJPbUo5XltA8U>V_!Y{ zP3x839r215bjwyhCXSZQ(p#;-Jgq3%Ou~;)r%fLoxOUg=;J~k6nU4K5hNR}_+Jg7% zUSyirWTN#mC1pAGYb2puIqLoV`7?jtdym`5OrdycQ3`wj6NiTAHX+$lsUQ-tLD@7Mo3K zzGyAV2%-@nggVUCWA3j+pq@PEJl+Z$_+of_1D zmoJ5ifd*sb-KD~KHuh1WzzMxy08$4@yOKX$C^qizx~84&w;S6iXiR&rl`bB%HTjDj*AS8B(2e@zI>ouMrna^WpApm(JBYoPT5#Jz zu#`i@AOPFDs`0@9ioIq1W5xK75-Ev^zGTfyhT99n4_q4_?P|-r7B1n%L&3+cI(Dqk z=mEF19cA(e6Qb<&LP3=M+&puFyqOnoV3+ZKZEV^;zkBD1tH@G#K3 zU^EYH_D+@%9R)C{YU#y`7ZYx^nKussJ>sJZin9qu+6JiuTs;$|9VPz!*0%fw}{{{e(TtY z6Mj)q>)ob?7z6|a@;zqPqe%DcbGbe{PLK7dEV@C*Z`xkyRf$II`cwPmIo6vinU&CB z>=qYXhtTmlJ7p%6)zlciywFj>tv>}M620rxa&!ZNjtGzdq3%X2ikiXcue?FG%0*d# zJsJ$P8>>qd+q!jYvFV4l0{2QLiM@N*AJ^GsaCiGJY&g>^SHf_;#h*W~2R%Ipc8Llm zW1n>VzK^i}+n)PzF}H8u?xA-`ca zT?)c&HEBpXp?c!P9U&YLxs4(QAHHS#0QUIdQpUNaK7RZt?lEHrWWfbe(UR|)Rnn|? z{`?cAUhj2dqj9MFH?y<%N8h=F`~6|Fh8?i`K*O`}4m*e0Hy1y0@d*fY17GHE_eT{E zL&3Hy>Dk0GdaZhJEQ9EA>t`Ecl9NBl}J10ty zB0&l!Cf?D}8}hHuZ-nB|y)V7IveNAD+$7aiP+3XIW$PUmSJ&c?ADMASRgWDj$+D=k zo1eN2-V}PKEg*EOLfo{ynja&3J-*91J-rO-jr}xtfuJaT@@App~`cyhmcenm<4ia#EsrO)my`% zk&mXzcqJ?SmKHl#-`>2f3ZA3>Ruf}mRnA+Em|wh@i9YJ^mVbR<3wRcN_sEaDd)eWT zRW^pq54B1d889<5YjDS#y@2pxRNz(~cnxf{_%(Eq=k=cb#|6{2(|pHM(13RE^2&_i zdsYSmLe&NG?n~S=YQ5`Zd_2RZdC3b_w3()+ru@GOkRe1NFY&13(3I`|uwDCPoLp6K z!iyI*^M8G54@yK&Yy|+T zw;r<&G}n90o*%~VVf&x)WZ!gp5(kE{mw}?6c_EOhqUClF*ef_7prgzka*}O}9AyZ{ zhU((_Lw~P?gf8}`wAH>#jR9;Hiwm>V-Y?I!Z^yQUZ&hfCg*v^H#+1c3)Kycc+XW&Xh z-~+YLE!G2E)u5B?JlEDEnfN^B*|TRIlD1Y>_xmR+ViuHh=@#M}%YA7RfGJtfdNRr1 zuk^Dsv7Z@OY_HhZ2uc63zIbux#PqZUj%jY673j$*Hpy5P8YefYbK5fvvy5$AQdvbl!6Jy+jsA79R8krlYwpL<5LMo#rE&Ng`XVW z!HWMu{d6gP4dp0cVgz`6=v(&5U+qv8E>x?Vm~1+0qKe~opva;#~xQ>x=+!-r% znURT!X<%UB5IW|L$kjlvE^~i&TRfyrZ^T)q6+V)4^#T3rghZ9k(}XQz*_^ti-;I%h zp_6%;5w_$CdV2feQ~k7u-IGxTj{Pic`L4W9=NSDoH^u2}0V&>^-;JK@T2@gpK3u9) zaNE3Kzu=Dk+}YtD2Q)Tq;o%9($;mNju9}@+MNe-EUCITJ&bGydjg4)eZCn1fOp*OK zbQ){Z(GgJhQAF29rN_y;_gUvOq!|}7vZymH$i7{>^`MsYZ=X!mAgAA7*-_q)V5?-DIMwE zC{kk=;5^vGz9u_6yZo6ef8?&rnl3mwP=`o z?oK$euR@EsyDynEdmPD@x0g5fmrtSqivc+deSITX_Xmnd1~fE-ZcA(He$j^953Meu zQs4_!=|L-3u55w!EMk=ZA$Mot^?83lX{ZW42e#(k@j`R!_U#o~>Z3gj#wGKD2lf5# z&CiDTQG&`cULq?j%*?j)c!bdY{rgn_jz2wCq2&=f4n*oY{T6&>%>U=zowBle?~hj(H#D%JK0+EU;Il!Ci>io( z_vP+oNwTeW^#kbNbcD013(Ot%B>% z=CrLBY1G}b*RQnYSLq%^Q}lR>?qNwwFKjixaN#kqa*p%h-P+Ql!3@luaSw4G41V@j z)Yq>Ayih?cknYsg)rE|ygksEh{Kye{=;lN5&E7nB zn!a7~She*ubsBt?9rwQs^o-vWB(NpNo;Nx=+HRnJyT$^*(%jPEMR;S zbRvhRmoHt?vowE*E)xEBfO=U9(KDNXFDje0lD6q^(oZ$3!~Bnr@BdjOk%*wOmA>I40n z0ZP7Hs}l=)YTDQiam8OwdiHFh#DPV7=LJB-{B&4=CSYMLVBJ&0KZ4PlOIli_>t2tm zCtvH15jS%%dm z_3G6{xVU<#0@0$zam^p1A|r3$Z8!cB%W#cs3Kmx=*o>wkZc;=sn!yWQDotj`d-EeX zJO!YU(AIscs;Y7U^$y#r5XQ=TboE_ric7*aN{PJZ3+)#b>byI4D8=V9UOzrIJ7JMB z=GY_(tJ5X4>p0Y0S5s)!eSS}2FoION zCOwOfFJ@wMD_XmDZJp%9xR|GAa7&;Xsh73GNwd?im#=Nx}98J_&%mCSX0k?lqqw(x2FGo zTYFJc(_N0M;1I<*N|mDXlBYEL9cfPGzNwN{vvWT_H#D$r+Em5l^=E)J?OFl^_I^JL zMV|eST?*2LRFz6ykGZQEP@%*!urv!adJFBVJ+H_toMJh94H_UV69)}wWHWJV!Px_e zv7&xYWbZU1N8n-NHKT|4L3RF~cWtejnpzGFIhZ3=fe#A$shamYMn)bD{@pWgJknNR zUTu6T1MaQlp?#OXYW4jU9NPXJ25JD_+_41O_f4Ex#}e)>TO1$i=qM}G85tRkPEMAC zp^g6ec?L*_j_^sz6}>|;vj@Js)V-}E5b3PyYDHK#imiPKF{X@-jWxPuM_=R*;@u*d zpQY@I-z#GguGFE!YfPK9V&7t~oXjee?XC=v?qaIcYoer7vh{ELNzcLMN}+E*yH&H{ zI3}8C{*#Uv-`d^PRl+neIoVsVPA}_X=<(7A49xVd<2`D$LU5W}IC_1tVVZR-Pn zwk{lgf@t)1tt>*FL=|w~Zd-{w0kRw~xhslKpGFU{?UM3f^8L>Uix9rY^zHri1O4qFc4go;GNya>?C~kM zxxpU7Y@7NdT68C|-f~(MIzH&ey+}zJb2EDdGZY)R%6T4}UPV%am)C4E&aX6 zb`dHAWE%Y?C?rACkCAiT zi&O)@DYOM8=op{Uj1>k32YL7H+W_||YIgw$#YD-~-(Ts0boW^|tvUJVU^#f*gbm+5 zaq%0Gk*xi}!w|uUOFuk3JhS-sfRB&QI_RCww`Gm%NBVyM4n<|{0=$b#kBQ+VrV(*TP&&Z5 zcxOsHzvtLf{V#w(AS>YLW+&%5z9Ka>b>LgJC?rlFkW3~i?%tlhK7YtzUC`a5(s9l4 z0Op)~FRTX2UB}9L>Lu5DTxJId7cPhhdPLH_H&&E(uc?(?OJ>{7KcBB+CnAs$@kfIY zJg9HSfBwA!LF0Ch1g?&#duZG%`d2vF+3ot1e__|6I}{8wgcyoz=kv%$gOFO@l3ttoHyYSqL8paGA%1BTOFqZD~Q@@jr>Lk z1Uc2kE%3!4ERkDu|Nebx$G$bo>F6fDdx8m*vV$ZK3`ycm6Fczrn>UeEY6u|sdGIEx zb7tK~Yg5x3xO{uAOIzjLR^U^detxFcey#=Ruq+BK4{wjmib1{$D>5hp_|jssfxr_{Uit?@&I8SY&E$?zi(q?Bml8vZ}p`NXqo2$VyKS$ap+IP?o<3< z>Q%&!-=_NO0J0uGeOmrg&8@F?6ZR@Sw0OUu>NK5L^aUSY;T}1E`pGdSLfRaP(UaGm zwxm+2oD`a|hJ*#Jr<{SAxwnk`qx?QhvJNT73P%TruJTrZCegV1NB8e<5*3n?x_i$E zaB%#%x}l+b-Srr0`}1)>A*{yU*>%c??IO6`x3xQN0(?xom9gtMl>0d#B;*3v(cUd) zl-eIZ9zy9F|I8}8|IeR4j;~D6M zYoE-H9X?jp)-@ir%fu`!ET$&flQxA%?K>jJ@CL>ca2E3H3fgdm-ADfN;En;jYWp)`JG&r8x3E2g9H} zcPEZP<2wu4Be#Kd&noj=8q;&05^&@gO}C&pM?cdK5!+8shyiJ(44!x>t37;O)_uw* z{`akV`>BbE+k(vT?OFo{gZ1%rqH*``-7|>=QRyxh)=A}_M{N^`tCtjUVyRIYkh?d1 z?m{zU++43gWCb#@+jixD%ysN{_?-3nb*BdU6> zqe43be0@Wp{SlR}qALQiRuXMd*$Bv!Lyug5{5KaHC zqqH=o1!*h7wy)H+wHb}7k5Vtz#XTHsu7tZvPg7Dk>?1u`w}rZlgW$e;H6^Hhd3j31HiX2!%Db18w}2B_EJV)=tSf<|9Eux#0h42mw&sbV8P6hnuJ{8 zc*r2iE^ao*xe-v3QE#Pbe(^0^wm`-itG$ww;IEbl*S~ z)(@>W_w7gKW$(Sq<;7Ti5`)Vf{sOB-w2a$cZS!8qH1l!Ivu7)w90^`=t7)-e2+Hn$ z8&Hgl#(Is8n8?9_%}6V>kAWUNnnDGm%Y#W;vF>N$H6%yZaY|J0&Z#jr%yM~zpC#F2X9yQ?*NX-mAVildHL%GV7m%H ztWU4>_mdzf|FLLqtRIXO6HChg_`v5oif%~5U1QmLP{^eKfs*rJC^_!aA{F6V*Wzp< z^m7Qu3pW)#QW=<>+#4v#YsbdMfcG43aF@25!?Ftq4lad1`zh(f=BoNE2v7CR%_?BI zpO7qkc`jo)?%^RVtw!6L2L%Q4VCCgZ*$4rAEWZCkVQJweQXaLB6t?2lA0@G%7cZF1 z&CQcjQi>rGVa>`Qk&GdUr||s7iP)JLmo-}tQaRp3^i0_*1(;w8B$V$y1av8*?7z{1 zDAU0H2n7WLaglo-)#-?i{WvWx9G@Hy%r5QFv-(Q=duDz8ZyuMk5QfUn&32^SzyA<9 zRJ;uRScFOlRu)+g&oGiydhjgyme zgAxllm{AV{GPos-348LTARMa*T~EeQ9?n-cJgg%zf=Y`AL&Qni8Hh8=H|6>BF4QR8 zzum8Y>*?!v!*7i8(%04PM78vLD^K1=_1es+snbx)sqvA$=G7hegT2$QH*^gRg+Pwf zXlY;M8buex?dh)NcY`5mph94Pl4|jcgaX@6@^A0q|D6Xfnwgt?(J>Aa1l!ubp=C7% z9b{okD{oI)rsby{`RLid#^L+hacy4q7|Z7OwS~msBf6eq<0qee72DHJlD)7kKba4# zeS?yyVP=L8Ov#9Zcr$w<`Cn&y>z}{#*N`Vs0#FBxx3jbJ&drs9AvCry>xfk13bePL zmkH0G2f>uM-^b7AyAX~f0ZcDO4h|0ZMXRAwt_wW;k4UlQfIsBK@g^DUQHc0@$+J&* zNQz%|%wecS{IylX(eV-NZ=zas{#mq^=SHr^)XeNAb_u1>f2QF2Je-o`Q)M!|_@%|o zTK}0A*e#2F1104U{))42exH47eusfe6_vrLt9CY1A-5E3Fgg1b>3IDPoGT4&tzwfcI3izlS;)b;Qln?;XIt~tm zd=72{<_DN1Eh+f%#o^Ejja)_fFF+&5^jQ3a5$e; zl|sDH5A+ne&+NCiR(KhJV*A?ypV6*zrhvmcqz@dx%Yxn@$Ee7+e!Gy+%jB%nJy(ZX zmPM-i;)l51AC_Q0CA24Uvfg_EwM68zN5J7H3d3<)-ves$U|4N%KI`HJooU-N|JdJG z>C>m$CdN*>^E{s0aB9QD{ke`;ull1lL1Yh-6-)4)aVu8o9&Un-)7lEgF&5<8#@qwfoH z^Gj-FTk}8635Xmh-%9Dzz8lm~*BTJ#m2fVgCL)Tg2xL{4-l3%Ez5f}q{xS!NzA-@169H6 zy&VhZZYL0OC}4ZMjn|$B9x4Jb43^gMU5J8I0SkQ@1>2W|sfsxid!~1ZOrI`!I^aCs zS-OfyCO84fx>-AMR1w`^hDEayCU@C?!c8T#gWJ6x;xcvsEkntJ`F8gw|5ey7@V_Zg z-_M;pR|NIHqM~AEw3CkD@$_^tjMt3s?}>{jL>F)G=;-hUK#2!SK>GRyf|mqu;9AZT zkYizdB<>yohq#@IPoKuauqLT**oF4+eL!ut^YQsM3_KplT6%O5p}(2&UM2w{j5nP+ zbqX9rm?+Y^XX_KxF#o^^{%#VoefMq#gz{E1GnZo%iNW`LX;Z}HA6y+gSqpC@PW?I+n90fS5U zTF2GY0-cAy>)N>gZFu>oE4vH$$}Gb3-_`mS+WU0s@H zc0vH;H1-eI)jYfehptCYdn|0Jqoa3Km}mQSruCtMSDd`(LJUbFf%Q~HUmMH)3=T`Q zmbUSd_dH8yy*|61*-%!OjO#h_@J{S)Ks7F14-+RbVblq>lp4$yY7~hymLoL zH)@&5Di-bzAtq*KbBGDG$cpd_2&k8^U!M}S>~ie7w!YobF;6FGu-nzJu4`VO=BTG9 zZC2=+1{Q@vDw&j+SVp+UlEwz!@|3XHt?){YpF9~ZBQ1>x9gD#K( zOss|2C3;f&>(^z;rxGGyuv#w?8xJCOG2)A$(PU@>>MZZfodsYsa)7i_4K_imA|x2b z+oxPdi|_faXBKbXChlqk)Jt$#Mv-Q8!ZjnGpE6N8K>VnqBt%g z&>v{cY9&yM2~8Ci76!$7l%;_>@+sX^`LK#gcZnaCu1@WMj=Z_v(5l?UnMZt$MNwbo?iQ4|Jo4o6stnG)!48;5E5}Oh#zdz zmai)|9LOf{k43ph``F(6%TV0K9s6pRw_9M?VV~9K73k~qaaZk~U~s&?^o0@s2@a%~ za5Qw?`t>`U2lm_NZee4)3B}+&u^NnbqAKP%{fy#^b2*oS-|9vuwrqbd|F4cqziQPgP@qK3ae0zAp6!7|LFL6Yt=P|tklGPVOTeL6EUGeW5GNl0#&K{0YL6y@Vvftu@s zH(#@64Ou+tt*4Zfmcu0rhj69p-zfp;bMGDpHtf!k)c!pPv_PMYSn~|&9zT5V?NA*# zrom4*Wu10mux2IXe@jct$LGs~e5M%y14tU1A_h&ACU#Zk);4?!#`RVqPw_Z4l?8y9 z%#C0_vekkEKN5Eu&{W%!f?tN+`SA_QR2~s9i24#0Ws|ml{rdIkIFZr2&#Xiz@yA^d z4U=o`25{-MMJ|aXo>!}Z?j>L zq%`w8Jgu0Mot+*92uMJ+ zdb_Qy?cpOw#=0&(N}B?Bva`3p0g?!+Ru9ebyd-@k)@2z5iOzg;U?Mt7TdwHH#Aj#( ztXsE+1Lc-}`}QEKwzYK=4-XF%80ABUy!i|wExb;PVLVD4d7hVgZx|8O;goUUgUe{H zr{lp4@Nph#q~z028HxUetT(hbrYYY7BoxqKS3$$&+`JiBZ3`kUock=-k=YGrq|H0Y zHm?2nF+M42C1#ZoXDe!NH%`irlXG29Mv%0hZy=h=p}BRBf{Yp%7#KJ)VRY=S5V{*P z5j@BYu#L^j@5+J0(2r2wZqI3G%IWk1^^m(lR^%5wDKZF4+CJ>)F!?k5V>3)spVU+l zB1BqP+`?k3{{5)nG2@Nzj1py#c@jJMZeAOS545zjP-Z6eFQ#qo&9{X@vjI*VLDInP zd$f|h0TOpR{UzfhkY$Q-rSvKyRj)i*H*Q=BvmE21tB@FSz48JQB{Jvbxa++%f$e^Q z{#yLypFzCHcO;oR(2f!~L*sYaph~f>U%woL2y3woE~op{P%)qrncAc&VW=YQy6EO_7(q(FBlYxOjJ+y#byLbN+6M@DIa`vm~)q^xM zCwNDcL*RP7dZI}FP)L&VelzRVks;+31Bf6&TZ1Hfn2ern+Qa|}g;E4zpNQVl($XfF zJ#*+i`2p(OC^8{f<|15El5^Pv3&k7Z5WWR`TQYL+{rh*szbf_XN=xa$>?&p_`uXq3 zA8BHJVO)3+k|S{iHcd~}jlxvfXVbC{^^^>B0`Z3c6BYv)6X^nV#10V-fbAj(6&Q^q zA*Fa&uF@;)knZOzAf*_n6~Mm?mp+PMppgFCfkEU0Y#J7?>d*w`*`j(_Id zK~Qm=muSjkI=LRRd+|L-0e~PLtwf=(K(4+XvL)Kko@Qs2P%b~*Y{L66sPeSu9p;CM z|22laQJaY%3FT87L2Df6QH;USG^>iztZC)t^iY(@GYy_=rTBomA|obI=h4#(xh_V` zbU7LDM-YbOtT6Minvt;tE>k)P6y!xTW$YEb*OoUBfhG$+P#+IqU52lparL(z8I=H7 zC`t<#M>5NPqVEK`1xU|e{aj|p`8BB0H?Yy3K-!u8dTZ%9#(qf5P(Xuv2zQP^bDU>h zOb1|;rUt6%3%p0hD9QK+EJ#Aufa?xnxk-i;0_#F{uknD|tPtCd_gY%i z)YkTcohl?GL}nm>SLiVXsAy@K+`)7osT*>nP$5ZH7&a}PTEYlem~;aL&ITQ;eOeR19d!#)Vq={6Aoq!>hkcs1& zfDs0To?A$}LmOIUC$9$2KYGs-8%jMiMg zejTCvjf76+T^q5R{H2YgW@z)ae@+4}9#k9moQ`l}K;QHD*5o{4koq9zbp^qy>Vu(v0L`i&j&c&|k=ZK#L+R6|}G%xS5#gB>xAOlG;zK@bd^p5Ucyy zGXeDAb(=N?0$^=JK>#K1^{CxuyJaVyngHFi7+k4ycsL9&cRQZe0oAtz;g?UikqYng z?;pG_3y84R+S;1vjxeQ=NtGvf6Jn?t^6ybOSPEq<`Pj^u8TRvI+&|TmC&`Qn@k+_y zKhh$-i3*UGFx-#O9w*MfAKE9VZ6JF0YAFqNU;HQ@nw941(?0mVXx)6w7l8CD5n%85 z^{ZNc1QW(MM?QIZ2S~d|IoKvG&50{4MF|cn7yHq*cMdPfpWX$yL{C8tmNWhq5FWm2 z)Y?d5d)Fx3@yFmqu1NyFTOVzBxr}b5NwE(->7wV(#TZO~B@YIHjDT{lgjr;wrefTY z0XU@+Y+UonVFon=P>rGQxok8V4UE$V`W#9mDVp49J`foQH4yj3$ z+UQnAd?UROtLpHvz3<}#GB<>$FZ^?3k$=4aFEaHf?qA8LPplxpWD*y8Lnw+ax|B+M z_;A_C3de&>7X(t%yf7C}z9B&9^5x5k%5nJc;SL-{>I{%t5l#}>7G&>%XAz$Vsoy}j zEmV2;yN;pP7Z-HLhqItVkfa5v!^q5_k!hv=wZD3`0yR7xMG?yqgq(w(dl+q;oI~n?Yb*+$3*^}#9mB&cuwRRv z3lWS#b9sWU#i`(#Z?ejlT2}Z4dxNh*e%)#Ag)O$bJ{a>O`;Wl|f_tI3=xnnx%5|WJ z)pz@Q4`$VXjwDD#1oBc;S_h{Y>7>B!q*!9y>UdNIen2=58B)TAxc%J@N9PuV4Dbss z0?vE-#Jv_hV6Pod=jg?$6odRo97EV}H&7Taw4!Mboe^S2DB#3Y%W?8?D_-@GYJ9zySjbmFr^F@6ZX2?*bl<^)tbBq99*AZy7)>)2&yFL$JwwE+)ofYrvz&hCc>qLwi9NCf&5ygK4#@f$zGPifMqzq_Gbimb`k1c(``|Gi!h$@Q3{*&d(ao69rGUNz zt%CQ4+}5Cp@3=}ridM}SBO%U;7HckcAj;0cmTLMWdJ(X1$E&kB-B7`$t*VKAm3 zj)G~5;gJzeBW+tyrYWA14$ea@tiYP*vaf7`qFk*jDlQ&*h8D^xzI?xJ+hMph2SP^YsK-I;yn}-+B@yfj~WJJa*m@tzZW8 zEEn3WGEju8VNLI~Gi-(D7E_*USK z2OTY5#s@OPuyrtLTsNr3hF#*%y!kxl&FPyuIPqK@><0+g%P1JITn#Ks1)N75 zwJ1^THlIBNYev-hw)ZC<9o#a}_py}hsIl_iqX3|=$mwf%pl*TiZvtK-lyoEkc|_?p zS?G-eP|l<-f8~?ek1rMq&DIQRM$A5|sN>EPU1z=48;27k07t_1wKv!%DO!84$K;q* zvcRp}v2B~Tjs~VS9Ni3uG4BI)`34F-H8!=@P#70TJ35JVo7ITLVU0Pb2 z=&^eup#cG_qlI+WQwZU7ad+pa4p>Q|oFp-D@uELafy2S|NY@=DAx9Dk!OV|aHy#Vn zfrKDPT%_*_=$)v6B>=M*JRfD^3+K9DuGXt0gBPS0LCPnSxuk8tj>mIX>;Na$aP12? zGqqc3fkpYNM>Q%xg-F3w(Q&UWmX@ACtm>K5@9#H^hQk7B9COxJ&`M9Fy8lKgcBuW48?3G56|Dm+9dx zIMKw|sM=+TO`?o4YX{GI2&$=sKtI$Z*Gf?o^y_nfiV;m9Qnki{7apzwJzp7RPbR?) z>vd%Kx&vG;1RanzqQ|Byo*;s63{*&f6+#I#W)GlF1MIIM`A-k-$a{bh+a)B}K!_`m zbSKS<=76$Au?9~%u?Fi_tp}3EBEHewqfR_XT7qV35?;I)^a|uOE zMWs@%JE*dYV+paMaQW+VT78!T`|XJ2rhf7yM@YjPYn`e^EXJK>Cg^l~ZEt2Q{hd2QglZG6a~hx7YYo zp&70(W)5$WKZ{1XHwFd+9eFHMrM`|FM-z6o%TMr_4D_#wCf2$vNo26pVi4KD$nbe= z|7wZCu1wXq!u-BS>_xkLr=RQTiclSSbH0E1qHUMDd+bdf9(R_T4W8w&(t8#|5G7^3 z?+OYEI%KoM;+UOLKDIIw&zX1`ZmdroB^~zu*ypP1>Zn>X zsFd1jIWvnpUxmCvNQABk&w<*}u!v!>y_6yBPpVz~g!ulhxd})IE;VJ2EN5SyJC*br ziPJI)d6eux-$T@GeWj`OY000LwbNrU!%=Iq(Kje+bb zr2+jOQBpcKrnmcBc0Y$*NFJtvV%Wzug3rPIxOR=hfCqCDlyqz+wGuM9NlC{Jo#?_8 zJ{6lo{Lq*Tr(d7lhDXYnfnzhU0Diy;Vqv-K8(8xWZ!*zUSyra{YxQ`adYpnQe5wp^ zkG*&pidq89aOz%z*!;MAa2jcN1kxr-n};`uX~e>n%;>CK`FRfiu7kG{UV4q!cEQ8*K{{^MxXNvN15w^RuD*|Na(Wi*_n(tgznY;~Yx_h$_S$l=mH8d_F7na+W?Ff_PQKBbA ziYj#F8Yq<{m)Z(9+rij?#Lxmg=6k-Bl2=A&;g0CZfE2_1bv&JE+V~>XYa_a`Hpg84<1gqKZQKh zjs$VS8j|;<7G& zmSIr~LT?ZAsl=Br*Fb-QMt^kFT58Ph3=VqV+3!lqJnF3L4-@ROVsrXeNWmB9?f^?{v^oQaXcrx46m-xWF$gnZCJN>Q7 zTI7UJXrL`$EEs=s9F zfR2Jd)w=KDAR}Y`qR~eUJeinzcdZX3^8u@Zf3Gp`0h7xQ9LdS$6*ZCh0j-Zjf0@#L2&QJ?<%;CAC9V{CUAY2i{llQ z>30)W#i%*L1r-odGOU|Lh?GlEI2s?Jg9NxjJf-sd9RJ+ZQ~`ASslfdj!g{QD%$yP~ z4G)J3E-E@ivX?;Vu10j_W37eBcnVano98)w{rulAk=;5%ad*eItbog5DJ|41&Cl`Hs=Pmo%e$RqszK@w&HM zgP`-P8FrvokYO7noj#xAcO2jU;QM_a$NM<0 z_w}Y;ujljeILH0G-|uH_SAxTrPmvSQ;>7zH2lrBHvA4J5Wba7zBEFrXSC{eq2U<0l z_Iq_-;R7-`h;_9h?}K3G5`uk;7-{ZNPdc-9=%vEU=mr&IqO(C`WWogIqF-{%$kqYvNQm3l3=lDeDS$l@?Y!WIhvy~BWxSc`@~xPNvs;OX;2%h zoxg1sd2YJ}UR?e0ytx11?5Y+as~)-lm1a|IV2t1b493Kh&$cf)d6q;~ z{{6^;sqeR&cWB*uH_>(qI_%AxH(!S9%?!z(KS=G%#LcFpuj8T5be*~3_2qhq%C#WkP?nw8VDvp_K=hjEMaxTG zKS}(oIZz@bXyDYDGn*iAay@*w1>|e|enqZ|c4vX+IhF9dZ#=O7r5%bqN4^C?Pk3MB zw|`ecK|8zePtr?X%JH2}`*ng+4&Q!o`MJ$&ziAQ1P|9RYnW z@;Y@DWVb^AYJyqT)mEwPHdcu~jK4rh_v5EdDJ6C0%$zQ$&dOOkX?jYMkY&j`LYbN} zWlHPz?HgvjJ6HkuV!}MMQAw?b501g_`a!?_1?sV$Ja>G#>QN1$RAy>3(=bdDmjw zC?0J6TAw+mbX7P`3dLIhz@e)iw3oJ&afeN0lgB0B7dAV0KX&$~E2f{r({}v)&;IvXy`)b$6u0@zQ&{CK_@X-g8I4g@Hm{IW=@;dkoR2TJHFqzlBkD65B%xat7^hz zCcUvotK|UThM@TqyU58%yCo38mlC*u3E(1KUC*b%r!otEbChE3u3m3(q~z@wpt(Nx zwTp{_Gn^&{tWLwutO?mi)B$b9U4SyfjMp(TPxDfAaZRtvQoOSV3uSD$PuPO@{T`rx zl7_dWI%SS-&Wskt7jre?-=1u@u)2<>Gl^1k=iwt6+mjFQqvrVW!fihm^3K}*$dQxtTxtrV>qIu; zTud!iue$KbE8m3doSp7Bj-RdA@n=Q;mZ8coZ>+Z3nd)v&LHHa7eLw5~sJ&!50J6~# zuta78oY(q`s`GoPrY7W>r?wiNAeVPYw}NM9T=DV{fEXk*Yv_vmmI0p=PhYR}*nT6~ z0?J5oN=h7KDceX*TqYCjV*A%x=$||{xFn$wqDtFjyN!vThplL4%|90gIvuMvx>Qm< zeEPPtf)w)~y>6BX1s5vMm1mG6_98w8HRNf)AsoFTMVPKrr&8xRv zR61vxbxRKhZY@4xZmVK$t!Ywy*FcV=Kxa>+9D6eoC8$4wBUl8o z$Env3`q61A?N(I4i zuGrGPPig_9fA{a7j855TxEJU=(HVGl836BUv^F^+tZ}*8W^U+L&w?u zdRpy86Btl)&UslM(f9C%8XjBTU$d~3+DIti6elvw`Q-a!?#ZFqf%3zMd&zIiFbHb+ zzeeRUcti0Eb7k5MhhsA*zrDBH_z^Bk!_sy*1Eo%b;#c_b1GN)+$ETV`;1bipIpLx{ z*)Z=R*kEDb=F@-LR`!_O+iun3%P-W|t{N5Msh3dJeA4|Rz4fD|<|cNj-l=ZIL<1Fx zmzS4`=KR7Yv#qAJIC3#vs4Li0HuMH1Xiv8^iCRZD+R@=k3-!wWJ@2Xy9Xii4?jy(F zD#xv^jwyjst5yXq2Jia%vVW{vb6|G?CONCCe!6?{)2Yfh^2G2~kQ%Z{Nw{P$+PQ2D zQ+*kmawwg>Tq7bJYq}~LMxQ(@^BxW}(IX4k&Gu76a^GGmu1R>?3Hy zTgmGYX{p}sQ72r|iQ8G(*?TZigj}ofE_)@XL5N5zqwEbsUl?kcZ>FayxwBw9S_r26 zH&(_gYTrfdDr*Md6tUVMQAHIK=H@ppp$7psg_#1%MnADrt9ga$;XJoNihcUvN8Of;3HkO5Vuy2Cr-01B%QxEbHL#AYhpk@3 z8ro?Y@z&e@?%)3w_KnKK#=QqLWxjac;3`-Mhru@)#+=wu-NlU2O+-JV1waGhW~8># zUJVMkXwo_M7|y0`*f?*&g8UXOP8#z^B8*1y3#jytUv$+find!t9Hr$)ZvXeSLqoU*)%IVX{6?9e=l9(s;j3rg(0aB45Y?KxXGU`QpwCZbru3W zEQbDR+1x3i{KBRme8mPFadLiR@q^E$%;nt97f9coqZZ{Zv(c)5?YSJ_Y8wryM9{J?f?+Vrki?M8P4q~}h@l>Fuu zlEWZj9aGOT!#c8i_wFg(hCP1!wj(IS&R<_VP2w;2NqaA@@>D~CW25TfO-$NvTbSPb zO?aNe<`8-k=qp7<*;AQ^6rJ1A-ga|Z@aBfqAx~Q&%YCD%r>7_AFOB<^!lTx=Zr+^C zxU2}R3>3Xmot15_8e_NKjxkzRU>Wu>5YXdcadBOo4`)he2`^vbHMO!&$*klpy86dw zw(TP|&0j|CoEv5P2CV7SZeh&7Al0*%F)6$T^`}7$DfX$G6I2ypRYDia;BxlQCTy=! zNK7@<(>vC$nod#7Wu1>-b=|*T9SDcm+WXW!9`SYtqK~65T)2?ILE2f&1Ac~gv=0M2 zTs)tCv;2IO)T*p_h*Q><3`*(@PHMw<1jZf>AVfV)2T$)W9zq*o6eugWc|EK6%I9ne zMJB|Je-!DUl3mOWfeNY%)&p;JfuGK>(%J(~b^&gA69LYog;0#3LP@_4D)Vrv+B!HEhMndM4cF z+(0b(lEb|LLclMn?#xg>_}(Xa4As#w3b21-!5gd`ujAf2)33SR8KTTu#WDu+8DuEU zC*1FJ^5{<6`Y}jBp1E-srUp?fA{j}CO5J?)QTGpXo1ebF@|lsbsxpQQ=897cUkX;% zRTTGSa^1Z_%)^0h-&_^xGTwSZgD;;K>Rm+UJ^#nnX8~_q5H5`(a%_iFk$Q{a?nniU zwtx1YURYe5Vl<#bW0YQvcMoGa(N3^NesjaLfdO`ZI0s~)u0UocjF5aQIFoA?r%-Sf z^c)@)UC%@uhPJ^szJy@CN)ozv>W8!^Q`1bBLmL+#_A01Dr^ zGZ95JtEK;a&w1bur^OC>#Fn$vf`ZsC(8&vq2ZIX*12cLL&w@9@87TR1sXF;vdcjg5 z^mcYb{&*7wBG@si;bTPr#B+YOnUV&25@{+;S#=d(Us*8(hCxB)1hBt5TQZHTp9a2J zT-rxJ%>md13zfCvuXSrd#1<-sEUgU(pNwCh7ap-je|6is;fz>^0;`8@jN8}r*!TsT z&25dSR=O4X4CkDc!AYuT0?v~kd?7i^X>vQicG1W0E2jnP^tw@%zfJpK;TFx1rt7o0 zPw=Sj()NRcCZQ~vz}V=iZ6_5qRO+7Td9Az)eOrQRLCh|O#dOT zn77tN(GlC?F|8D*2-|eup8Iftxd|bPc?)y%W0qDQ0Yr*@#jX7m=f|WLx(J{*M*lq=-y(tlWnOoip^z==F>@m<y$IZc$UJNuTl!T(e_p0kkC|+Y(|7cp3<13_-e|6%~v! ztT?12(v94=ue@c_Dv-L=#%1QIiB3917OeTYPMa}9{Fe{l@d{kZk{AJA{swe~B^E@A zXh*d^tE7qtNR=H6h)%^Xmq#r9-=Q88PQD5Q1pSSuO(wSZel+5KaR-Ip1AI38MCF8A ze&w_SirtWfMNX`6bZAs@seJrmJ-s->fp9^Db08EsKs^W{qX^diyge95#~yy~ii8fp zem@wDY>41C+$ZnB(0mU24o2L8cEDXwV+`LYc4jbO6~d74Bq+3m`6Hfgg!o)S!ZGWeM1CPX ziaK1x_egwYEeZMgupA9?CVIcyah07!^2^>DB}Kow7JLTfY&|^R=AF|DpLEZ#y$I%I zxXRIBc;yQ6B;EAyA1gqUF7UM@0>*JUWYG!QhIZuZ8Sut&d)rN|e;rUZra)i+v(cYR zT_>hd>plfQ-sprhLN?I)|IH4&ytx=*_>U>t9>af}k;#10{yU)RS2pV6J8Y z)9QZ3!s+D-3u98=B=@g~C(O%M7CQIg!+C-1t!aQQr5SNd2$0k(1X*|;BEb0dSB)pvZ;qPnJrr8fI0 zq-9Nqa4=Ai@h6o)T8Wvf{_CEbG0gB*L-47ogd6nK4On#Zzs$+K#N`3=PW5tQi@`#! zCVc%4gN&oJpuYeOALal^2S>`;4Kj#tasSE__gaPloeR8!ifD;LkV78d3iGu;#rTo#02?G{lgTUdS=VwR#_naOM=AS#?#EefRZ+hPR`8ZTh&~mKq zOeS@OzG2>33)cefd`DP0vPbBDJa`a<%m%%gy{lQLi#!?=zM0%Zf^8#OGJ#YOGs(;f zV73h^Yq0St-x@7DB_vQP9utZ&1}D3TV|BZA%9Tf1TcQZq6G(Nijs=QtBc+MUp9?hf>p~lR2fbi{_JME-~iRxGg=Y z8D!;70Jl|yCE{(wP;#zI+Q}f$SWLa@18QXBk%anK>GzFiAv|n}@#3ApYkR?3t9}Cp zK2?@FUO775=4`uD?Fai5SkyZ5T*FIV81HDV!|#JsC+FZ@6KSqCXmalHwl?ad_A1^?p?ihZ4;7>IlFCO z*Gk*OQ`R>rXb`E=FL8Tf z0iZ)E^bp64Spqye7M>05!o8Jt>ptGM+Pkgr(^e zl`dhW(mxg!DHOe@fQw6uin7W|boF=b8Xp!KYCF@|7hP`|Fv&NsUnJyZ?Yw)psZAQ`!Ud zCQX`%r}xRzr&A0E+O<0Qp>|*DCG67JmX)Lq-$qCz)LqNqe5FHIxtOo(2^IKCJjx{v z;}acDP7HUwz!=$%kV>6AkgvZ6(Gm(13lSzpY-O5~ENVgSury*am%q0Bj%rnbpHstV z9W^D}vu680hJD-d%d$P5;U;8h`A#$nVSov@7)agS`c*I1Xtr&crTHOZ(-HZ^K^HE> zGxWg4wHuyy-XDhSJCX^%cNti-MkVJSpLIiV_6D390kVYXA|WCoBAtTCrbe%fj(j6M z(sj7z^Vs+Taby(Z*|E&Bjnei}kjroU2lxcs(g&fJ^fnAGS{*!cWQj)ROC>MAmnlEv ziXQIL82S6!+8QG5GJ^CeTr#{~S%bbn3uSFrb>R}QHs5~|y+kms|U!F90 z*I9_JoKgL3z!5#*7Z(JD{3g zY5%P{`7>7fRaTY`Ji~>9YrzD|LX@BLPmS2~Ds0tvCVs@Y&DT-+%!5&5a7AzbFr$c5 zftN1L{a8zl^Y>RWuKm37sZTi!hmlK{F6D0YvJK)|o5N@KNT__ZOXHuy4AmbClW9Q> zNqB}N`TpCAd4i$fKR)*yP>;N2E8h}yck-9RXO{EKKStUuwLD7<45KT&3FBhUxiPc{ zsi4clDYz667bWiz{_z_YHR?yNmvWz?cO46i~vXZ zXWhDFLT4h>sRUrzlobOJx=~0S;)@>TD$P!{V`ax(hD50c8k>Tpo0H|uxC~DbD#}Tq zby^U#%VM+jiaHYhMl`0^dEbAJAbWQ30=$K!J(NFXy=z_VZ0+t{s`JvrC(Fss}9=`Cb z3Gsvx>$wRHW{}ZmO}C@(sDyl=HC{s}E~gel)s2{yIBQo3go5NK4({G5G*UyDiJ!sT z6k1DanL|7snGa>3L?7nePd|L{4S(>UWS8&Uho8t>r^xNp#>XOD2POABa;Trq?WwMt7$1a?3r zK0CBpJblD-_|?02-dcNMvd<{8KQ?{&@~()&*8tD`Wev@pdyo|Bz3Ib!UCXV*dJ}3m z6DGWVa}t@sAo{ie;!QFO4=j|6l(-|VNAmcVBA8kgE%Wz$=cKN5zMGYnsi3AqgY;m)p(Oqq7 z8h6I>oH)0r6%?5(MEb#@nu;w2e|@m?$~I!rQmwu3zZGn*i=mTMC@2d~42h)qIS+6n zLsTRvhP`b2HBy6diMU`A)pq|^x^47)pLfTRrzvO;Zd85g7?{kv`uG^?u63)9rq>mw zzfd<&UmqMlVPA4TG$v#@X^N?96lj`ZAR|cV$QhPpGhmSi$DhGQrr9#+?-*zABmWaq z8^~;ss*014jqrfG-f>&Uzb-Vj>!^uXm7obaS&vGU>Me^DFETz<+u-ABJ5F~SyDnrU@T51~Jn!TqV77Lp0T;Qny-tObO^=pFEfo)ny%SeI0iP&G3p|$VgSX zQLOhMopjo^-Ms_qTza{JE@}8!1-Z=b6D?wRKK(E28>d=FFA9b-NFsF~KcSxr)u1#7 z&~`vCcotKYDx1?~it^zLiB@-Q;Pc3x#r0XIH!|@ovT ze>wo|AKYPE68L<0qOMO*BOIG%Pfj2>^bBmp;i&Qr=Jbfnm&DchuK|jvH_0S zTW=km+IyStX|z#kFr^lqm<|-Xl>u7{77m$8mlXH#)4_jrsIAXC)zvzc*(4$aQpX_r ze5%WbW;rX{jBnEXJ-@HSp{CQlZ_XY$Gwcr^kiEN0UK{@TMp12bd5kU%qp#15|ojrGv~zE>}iVI)|y z=-&9clhMkN?WC=w253Sr;q$U7^Cegl=X4yzdx(SG0QZuqbl;y-a0Fi6AX2Jx4>@YzD6#w%u*^ktk+~IQPnv;NW2VbWUcTx)A-2h;7BT zS^+2cSNl(&Uaq_L%5(d=+8&I+Vrh{?Dv!qV25u?$pjDhlzb7g&RC#GDh7YfEZP6`h zgBEPi{=)j_y~=1ddSJZXfx^B61`J4NopF}7P&ve`MImq@TFXX?EDpx9h@I+G^g{3{ zeLKOnq#usmM_>||A&*k*=0L5cozp8F#R?`taAHEdhrp%u8cH!BPea77z%Bti72j>* z4361=xWY*2lK2Uzex5jpe}iZ4<@VRVmy51Y*2_g`WQ4eVFOSFjKsIKfu}#OB7%j## z!$LBd*d@zWI*_LGRGA9v5GPqe$9SMg(YH@y@din?rf{PvjfeCIyy<&ukfx^aXqYHI zqS?eoLl$fceUGBo**&GM=ka}x8%SDcUjh(8d`Z9A7eUmZSL#D8{)ZS4Lm?_A_~H-o z`vN;lKsoGn*WOHS1oeu3FQQ~b!8YIH0G{Aa1Mj07Y7X5rj@C(p|7b&f!?db2-?ID^ z8QD?q8E^zdk^_9NP+VFdIvhG?NmcwFlS8qP4QOepR2u-bX|68=-Vsm`dRxD|*G`7? z#{k&Rqa5sf{B9}@z9DcUSPMhesLNq-YzuIPrk2AS+*$Ky46>gAoz?7m3{yiXA(jG# zwcD^v7jJseTwpoM3@DUSipoXLVmTUCQ6%#;p)fav_$#`%B16qFV>-}@&LDY+HV`}_ z2CD&XK3YbP!G*Vyztz@La2g+Rk1Y#7X#?Ym*MVT1udf`>vVkA9URHRmf7wqVI*PoM za>_%|@6t3fmrTx_>$&5&Tw3Zou~LIl;?Ln@OIs_N$QAy!|6+wo@^|G3cU6!E39+(jrLA2w2viv^cj)l`0sDWd-PFa&`O_)vrzj7@=`hfVy*y$QJ38K5xyg0 zJ;Q__glr%0(L93y)&H=%G;)m}(c&m3al{?Xjs!xa=DVs`A{j@9Yk3k9Ob!eV*d&j5 zZJM&^L%x;p#p%v`w?PCVbd2=}M|3XOO2@H?$*Ix(Me&C#0_THI2+#~nn*+^NT!hiX zC=^l(Q7?**1O;qNOpHRt8_+7OK*|0>!78`%7|#o9H@5L3DT4Su5uH}jL!t(2E0vLm z5J3@rZ8iaJ3koM|E6P_;e8rC6y~ywrQckgjRLJIYdRnZ28!*&(gs+#Pdi>vwn^fwt zNbteBrt>*P*TlWaVFDoQ8q|J#k1pd23;ZM>A40}5ou^(N{?Dt4n1JAhHo3uYA$zQ?kn-jw`=#uDuuj8pk9ZY8!YWnrV0oK ztx2@jkXH2krrS6k@H+EXN!^DTLMOo^Ar&lQp)K16$G+Q|P@YUEHbrzOBu&{B0CN*^ z>&eAQbcm+FKH>q(`2aL?#reKFj!I_dZGNXD3Cy_Ss`F$%h#U#oA zb%#G&zw@wQRzSYYx^4yvF!vFwe)t(lu;35gjs@bq711dIaRsOJ_#cnstE{a@Qj#|x zb#Nrq%!a+H&yS6Ef8|&|I`suV-n#3aHIYd-tv&{qS0M zwm2eF0Pd7nE)p(E-u7S`B`bRfMP__2apGl4Bv3Y=$y6578DEEu>a(Np*nqWK0Ak|h z1pIt_v_?;65z>aS>Qt1*RzqM_V!9%_MpE#BdH)U>N(xSb2q6TOPw|fQV}t<@5097? zA)su$y(XwdLZM1O_E=7&6VnG4Q>$q-bH*s(>7)BRLErl*0pd|lNnjG}o^OL8a~yZ` zIg9I+0Yci|LuxB5at0~tFgc9k_degUiQNITY`K8ci4S6nNtjd<93D*%r3oK&+ml;Gm3o>1t}E6$9H0j`5N-s1=r^-%}w#4 zEHuG7f$7qO0<+^x>asZMOrPR-O#QMcW7Afz~#Wb3+RtX5&D8p=;*!*d6P zSf0@SoAMgTchlk3^(*dRyRdJF@Gu7q$iJsNY+i?toFRua8SNkAyM>dbQ8Z zttrLz;KAL_`uaX+X$N3lOF*1P$3iCVxnEo-ttS+b8!F*#^tvzFbf)?Zy(BiC?nK?* z>gp^c?uxo{Ix*5VvJH9OmEdd!k|>h|@i;tmNLeOfDF4I=SfVOPTzrhh6M=~T0#q+` zZy~UqFiqtH<3<8cSdt#Xwj(sn1y#KbtM8QYgY(>tJIfE{ALHM=@ha6DPL)CpF%7Si z{$9&&a>sKh_r>b~f|*ajSRq0O1uePWZQZSQt-eaBNS89l#hgZ=C1&@--c|N?bPMA- z!Ud^Mc`zL@184vF`RUDZUa`!Poie;shSMXdK$U>?5hGk+QYofaTkdNX7wye({ihOY z3^O2cXTh(t1Kx^%Z5w`liRG__)U;146>lDfYdzq+Qqgyn#mQWFfnYFQnZiAy`Dg|D z#V=J#z0!^zhV?C-V8IAr5D?H?ws|oDv+g=Saca$`j}JP)#htu!rv`eq-PBJhHT$>x zu3(j5{Qp3TmUL*Oo;XtJkl011_=Bai(&a2E@gA-pd`F+OMs+!3Y!+Qa(n(WL-o= z65j~o@*egZ-PIwvK>1(+fwDX~s&T9?r;P}PK@h`SopECfCqg%9UKfg&=_6K6PjQo3 z6=55Wa}q_6WGEi3KVTN(`o3aV53DTA1H!NPW38Pv*J^Jrv0kGxnsG3ws6bCtEzCXV zjojMLOqqv=k=PXRWYBTkkJE1~UhYh}Je9Zf?+BNa$}9ks*}26L?65jB`IZqJvNjW* zt;vdnN9i5%*zKBa9f zDqs2p6Vh!_d3mRm_b8(9oEa~>q8O?N(Dh^K+C?@H<2hVE?HZy(@KY&wKrTVMI`T^t z3g%b~RziFb4MW;Vgz3XKrpDgM&ZQXK8dVHL0(VvrI^d2Q0)G zP(h1Dkco?r@0-@i2VzSc&93;w(8LKxh5Q7!I5?~MCVmZ`7A0}if=2|sq8I{=6Axvw z#dGfaW6i+}9abEu_~aS+Vovw`f&y1@CDC!{it5stc*GSRs4=dH0gk{Ae0w1qQ>1>$ zjGoBKs9vSNlyM}(7;qYJUFo4AL!bd-RKjR+NLu4ZOcEuV3cY=_h9k}ZG99MBm#sae zcN10Wl(fekn3zJO9QLx+nu{M*$P`}8z5*6O$EcYw9suP?pbmaYu?Rl(qc z(pI!I@;Ja;o&U!%;F()F22kv>=+4o5L>*U8!2*LlpynYF$=6b=OY-J^{}kCx{!hRL z)EI~5gupOyR@rY}4kKQ6lp`&OBQ<&W`YQ8+q=FNF+y{0P7O+U2nH4l^*~a$A`!Vxy zV34Kj477I}UuYmdS=jjeWU0N0kyBCsknWG$zW3lqtfnepOUsGRA!t#slJz`s6r>;t zgr(oMbYJ;C0La3aIJV1Et4y~z>=--@Zs&fc4zeSa)>*HVEq^ZjPVqLQ?ZJxYm~yxd z=@>lPsxl!~h)>led#|zv zZ%KE8<-)qjv!7>hipFZq5j(h~xSaCx4lYi+i%=iJpVaUitGWM~gsQ>eEO}{6YD){p7@t9^M(6X;+%_ZeH zwDLH8+MV7}0Bz^^xRlZlEvo-Sie<0_A(Uf)no}hoMLRjgZm6cj7NrN&h}h0z#3IWW zV~=_wOgIX=PYkuV=Az3&A@sDWMe>sEd zb!ImZfJ5L1vM?L3Fh3LWzB9ald(yqwIG4PObUmtQJ)6!I(?i1t4I zSRm3~2lzjqHCoHoXz zC$<*=1u2LikPuUDj=9sj(u|PBHZPHP31dctx#VXyd(EVTO8gPD;H*2>BLEW#A#O14 z5WX3}gd+LIR!M+Vt0-y>Zh8gXaz@+Ok#cq_MGN<1FSThlg`*HcC=GCQxx7U4+K=&j zzOUXhaNa2ed@TVgK?S*>vStSu0*a;~{ zE*Db%41Og=O!coC`0_g1P5S+O2TW6oVWUBz66{YN$02-qVIbK{3rI##^?_BHiHN&>=Cg70+&n9vldt+y7tcyQhH#?q2)JC8}UoT?NK&@ zfJqCkL|!XBn?#X7c^^Pnq4W{sGdN`~2urVio*TaMh_Pd3BZBc%UXyOvKl3G%lZPxy z)9blt{}ofpW5A0LWL9bw)H|HIP?#EGBn)klbW_en2fHC}v8tN1{b(Y+ z{*{Y?=o1hqyO3LnnS+*GXC&T<^@Eei1*Hy!%@$9Zq<9-*e3blZkHf!WHyI-*)2U7t zN53(8ZX+u>Ii0TjrqztxD_A$H%TVd&z;gq5krrv~NlTHt21}o|+u~@cG%6fiv1Ji+ zds_A0dl*5PMjF6^$LGZXBs%{0osu2_R@aH z+EY@n+!ECeW1zo?v2$xmC5pB4M3B}E>me+Rr8~I2#2mBl)x0Yth`l%`10DTVtz6m5 zbC$?8B(0F`ACHPm@ejn@iLb8Uf=Ad_wxuvZxbD{3^63`a2@@UbYIcwqdP6_#zIiL` zb69C|MiPIZXE;oOdO;2)gc|YC`SWXib;%R8U4C z4NitbO1SuL<>$6;4H@0WQt9X~4|R2MnbMlNx#1UO4W{Sf*$12=G*E0t>SaYIwa|n< zA|RjX{!O~cg@qS}|H=3I5LJ=WoOz%KCRO8q7W8}{epxOMWn9|^mvDN(53jkzYUskFI>vx-(Y^pm+DuWFCZVKx;S#acAg4l(&MqYsBHF=?`t~ z%iG{?vL>*7`x7DOddiEUQ|Rsehh=Q?O9b=k3~qFEPfJlLH3rUYQLYOyo+i6li|j zi6SP@ok-mJU3$!_?7n~@GoOL=L=D8q+Rx#47h$+k5mgt(w*r22e_txFJgdjgbdG>c zG9J)I+>EM=AF0^>c4t$bj;Nab9RGYh+o)UVT*GkzS8C65y#Eq_VPG?NRdU`fxG*c) zhV3@xLr9nh@FJM5^%;NFTZm&#b$ZsClQt6<<=Fd2>p{Q zyaR}tZh}D&1`~PLZ)WUR4j|mNT|03~N3SiRpGT8K)-mWhD$l^xdolb5rEaR zW)*5T?*8?{SOCL2^_gemwV;L^p724})%`W~ZO0?JO@Wc>gi(&a-YgCW?LJ{zahLe?ar+0<_^ex!=f}y8=Wh4NhjkYT<4?Qxn6jc zQ^9xiQ)4VKh20-3GLU?=BBgwNqKceOjzaX{Ylc>dK3jcL}7YzgU!j;bQh+j#5mp>UjANx?7X zf24PoF0gZ!Qe1pwfTmh9QWUQaMyX<|%KTL=KN`lG>jr6{19(ciTfdeW@i@;4p@)8H#Hdmd`En4=^CYu=f$EvPs!lj2(|ojU_W*|J9SAR@ire8jKED$ZQG{>#zWWwC)ZYq+3+5qzG* z+n&|0Eo}8EeLMv`p^thOUWXt<*vQgG0XK5L4y9(r^EN}7O?j;W)?p~)^8Rj0%Cx-h zM>t50{NFeK>-UHy43JanjvWUK4*|`le7*X%XzH1e=-RHp49_?G>#O~->Av| zr+o0F6C<`%R95C?G!J8F zKKmT2tZN}rQ)G(5d`9*{fQc-aJk2ahfLr0{`4#dNG!M+m7X($gm@ilXKi2ESIC;gb~r7E@k zwHDqrH$In_>vnP|aUa0XNaDYiOnAyo(`!C_w$`kz$Zb^|{vIPG(5CwEq@*!P-0_MY zuRhJce#SlSWJR{CN+=!;g05}SFr2?8AI>E`@Pgo#foh2P@W!2;w~`Ox3O*rfaS8@E zKUY-!T!($!+^Y>2wrkr~%$TS)W-2ceK`+eDN#s&l{-o}aTUZ#f{o@0%%0LVNgZ*6k znGo&=9yil*n-n63T6DEu_XRxPG7GON5htu%={u*#B*&Bsb2#}i{NS>mjkYs8VGwI7 zVe5&_$<-^Eu27(JL)Hy-9b&hKiSUZRak*>#99*aa%(CP5GyB%d;Kxlsa}-r!_b&_m z?mv9^7r{pC@&4hm&0M*C)Yl-7=ihDXxSxJgak?PKVE92dw>@A%SP1T=B@Hf9XC+4r z*nV;Tvu=AYjd)RB9-o@pQie4Mh2JOlOj+NG;vn&M=j98Jt2EnRpPXfCmJ~6_&x$nZ zR9))BcS_rCEQ1e`$*6mr^eg{~tE`Nk;RG+kMI!zF4T`U>Gv7#i3-II0^}Bj# z+AH3m04;#`r+c68hp!6}Rg@D7k9wQ(|}R88hZKJ$eNH zw#~nIJ&eZSq{buL&2a4|pz{?c6|WiXSOF>&4lauRjJ6GK%?tvt^r=u5(uEWzlG^zr zmr)KGZQCi+)R@@Vdc4zQn-%BJC++U*@Y=fL7my~LB^vq6YfMRWc-Xy*dRLbBcP?>_ zfMBRPOs#+0rlM;HhwM9MjDgnDR2$gfz33*Ldr9b*dX`E=bnKBK0VF4;jnhaZkbtZR zr+v5T#7}W%Y`{=3YZwQ;$1lH&*#&Ud$Ci%BHgT5cMLw%~B^urbdxJGXA_|F-QrzBCg!5mZq>1S}eqkY)F?St$qR zTIB6-eZ1zp*;TsB&Q8`Sd=F@#kjTupmR67|>+(ipNGxIxL5=imRMsg^Pifb^)}%y_ zDQkG=_HFyMHY$US-^*7bN?Eipk#`7z<+5Q|(DGeOm1H~RUwf1d?>d%R@(~S6{2(Yo zwmi!B_;xmmP|e$!ALy;itVoc15SGggmAwwnux?zl`@3c(+p4a-b$;=@U|&$-YYy3) z<&#vM-hrYg?(XYRQ&P!HO1#AJiAkVjrG!ijXhhX^1Ps3qWE{%9!jIL5HDzZozVLCR zOQ6se68QmOUR{f2>*UR4y&0RHUaNM7e(>6dnUkuzhlfAuESoPhEdA$PGx3=xwh{h_ zz+~BY_Z~es>4ycpBg`hys+XPa3*Y)Ct!T2%1cil7hyO0n5)Xn)4NCKZp&X^KTja(v z&(e&m>%Zjs0;arHT~=?Ma*2#^DC!LiMV`(+t~q%VOc`)8=Dh%WN|t}oslcNUr)s22 zR}S^^KT#T%iYkgsC7`%8igMdh{V!P0%7qGdAa+Kts?whfD#cJ1zK^($A-4qQA}ZfY zn&cJ~vxwDzsJC{5Oj3PM%?))X#0rl%)Tz!9HCGPDLvicLZ4nvK7$-zr>%M&_{PnJO zA)n_>b$s$~{y>>^VTtV7$-TH%oD1f@$CjG(3sVf9%EA06_CDS89^#IUcpq1@U z&Bu~t(%-+&H|*lKf;9oUv=S5`LR^!gow-)fyOsG|!@blf$XKjUefgpc0>^mq{uPEr zjfs|g72Drhs)t5Sb#zkM_F`a&{Q-DAL3)kR+Dxa35F0dKEnB^z2t0;7-}( z21HGPFBcb(R&r{ReY_zM%C1eS7`3$9iH{!r?aw*w?=9277UJp0ta#??6<--<@c1qp z{8Rzo=M<_Q9$&J!r$I?kQSv*F<3MGYZ0t#@JYLpYslH@K0&E-%iD743bR8YTH!5P@ zETC%~KvUonPpYa%->knTY?!TW{H- z1U)!d#AP}OJDf1$+?WGB91KEt@ACJ>Mz}2oxt5#^wckDM{z?6VRUwUkeWZWgx z9$>FKBg+Lrm01in(1hel$HFL3F+$FB>?6E9wNMB>XQkDhq&=ecWp68d$Sc2V)}@E` zmU$FmDlRc``p|a{U`?uNvuQ?HnqpuUKXrad>36)5+qhQr(6PT?@nIr8H9S=V$GI(n z)fSC?`}8r=%wfxKOJRdC-`ryD&pdg-Ndc07J$PHaDBocItAsCAJr+CWgJHS&`58xz zcO0b6WDvA=2H0tDws-3ErGFDaAKh%odq=M7(#n>}<|76B;@x+8QsweG8_EBOtXV1z zrHf;=q_)gLX~U+lWiDm>e6^2g`z^fQr>yQWTD#io8+~D-N)-uSY?Xc6^&XdZG;-ze zcNz^>{B-sK{I9T12hfn75jA7I-S_^}K5-`MtM2<~r-@Mc#xx#%B7Vmp*NWs}l!LiX zt`P)b!35sA4&Wedtl05^szsDF%rA7myY;^_6RlsqeraTAc;%@bCtp|)0J#UG__;nU zn5i|BV2=V18zwS<%#&f6y^iVz90VE_OHl@}>EHJHt1{a!e^S?3Xm18#=4rV*u{Edl z?);wpAEbIdj8hv&KNq*4k9| zO52S>G8O!27oEnr0-kR?Icm&}SFhXYXNFrfn*!T_^4|rxhq|PPXl5b1zTla7H49@N z(wiv1;b^<8XuAO4feRNen&g+=$Od=xJrHf(JELGT(}v2pbu0B@#w z>y=Zz2QQuxJcZ6^-`zr1M@POtc)G;yRF#GVeL}X7%+#|eCJe2rr01ick=5SBGm$s| z(0}bvz3bhH*hHxS*R07VQ0*!|X^0Dnpgt5bVmgx^ZgVNMIxjyz==Z;Yhes{~el>v~ zh%od{K`Albit&^%^8?3SV1^h~ncq#?FnkY_OIQARdba#7{@MpcJqbK6$U9WFva1ap zs1#7QD4H~BLhWfnohjwb+tsV)&ztuv>X9@bPX++GQ*2#a-uF0j0#iEHitk}*`ny~N zrROlk(!8T}xVA*ouPW>YK&_AfsP;@j=EO4a?N(3>WyR><2Ns>1S$Sv|zC-Ps?XOUN z@rW5dDE}Ui&ad5*9sm)qW2j8APny^#f>x?v4cNV~(S-J-sB1BG1?vm6(#~dJ?mKdrMG1QfOyvbe% z*Mz|Vk*(w)7yFiFpFUHvysXl<*@dSbnVBF-xF91!OOj<3jTVV&uRg#Sz5(i*Zk1iQ0e4;?(Xc=nkft#ziR^(nBu8hgGc ztxZw&$~m(o>czu>We?9ST~g^wn%1}bXGpg#gS`t3-g(+pRi_jdauu{aSYk&N>K=KW z(C~tgDq}qAt@bm`)Nh|lYHeF{Qgv{&cleW)leX^_>%d||@67U_)jH$k&EYR&wY>HG z6VPjJj&Yb+viNs$VVDeW#9R=`WP(cw@ERlqz4aQK*E#IDTc8IzB|cY7W7K0hckiwY zVsWSOw?>nFmW~;7?sfTud|)?JEpl_DwiK5mbiTlKSI2aGGs9M8%$RlhFMq9k^=ra~ zl|fI_2%cZ7BKxRXC~Lf&IRyHmjB42Z2m9kMdlHtDs+TF?g=yL$->n^XZwr1>`P>gt zh!l!Uo&5=YV#PjhfY>F%rJ7@dxM#8WzpE*85X%)OTy6<~I3%l}ogW0^N#$ zrA?bX9IAMv$M6SVe49Qx`u3`U{#6q${5WlGbhO0CzB12yd4bc*t@9teJYih8tK~F@ zi5v3jj&|47+b~zd!Kfnv>h&%QXiw@*M&ngxtB*f6FDfmKryb#Vsrc3$P6K>jv4>-t{FBmyU70&^z7{ zZ=|-|s7LfEt*IKQn7`i1gmEWM`n(0%7b}4Y)HE@Cf2Pa9gLBW7+fv;_>3N-QwUZOWtgD)wZt*#A@P&0bW3EbsSm$}WScl4)iK0h4#DAjkYtz3)XPz6*+ zTv8G>H^*Z&Li#!Eb>e->Q%2J>s3+wk$-4w|LCw;Hh_R2T_I2J9;z#j2W2XQr_llnn z3lW&L0`#_uO8pQ$m#MStA`#OGG0fwS-bc(0gjAh6(pZAVXO;2WgOS;63z2|)HLE@7qb0a4HIl-Dm z>r-)Zd;Yk5h!0-AVnsX=;+y3W8XK^uR~mjMM#_~n45Aa}7e$U~Qvb!z3OQr@ailH+su>4-xWA+?B8?>gPycYmVTu>M) za%N!M9-&~ps?|U_Snc`Pfg1oD*t4~tYyf+riD(Jw%!D|h?!U^zI{(dwv;}Kw->j|u zB#TLU^-?f~0cQj19vI3LLH2;rsill^8eC{b;1V~Ksl%sVs(zdObyS~5FRJ{%%(a|` z?BVfo*)ju{<+Y35leAFVN^BPoaRV?C_zY2=4on75U56Z$ zo#|k#vNr~cdrI%gTesd$aibTUMzaGaH5-bGf>pDDIw37L`^Da88hT@ zo^|m!b+t^#s!ai@*-GXGKH!{Ulnu$@-4F$jFEYxv5t6*( znN4ri!pNv{!*Cp;dDo1K^>v&u?DHR4ci7r6b(6@huIRhNH~(_hg+eeM85q4s5~#B~ zwEV9|5K*?Xv0F)1nk z1^eCUf_**fVX5dwkDcR2IY|Q|Eb;0;-$J?`6-o$TEGXa9KAvB+^=}Sf)Ve&g8%sUv z!%2u#-P)W|$1n3y+^^w7ZkK+ek+jf0>{ULa@~I3%rS6oueO@@mYMr+lQDqXX0sz!= zeBTRMgonM=@hvF#KBmh5JUtaq^6=V^(yi9k;;WcowD0fy?Dwbh)#oPe`Za6Q<}D7< z1u16h=FHu->p@cQA6-FS_~h;Khpjk&p~wQN&uo11&9%O)J?iwEMBqPEDcC=MM}LN|+gkAQb#K!cPM z=JaceX>^VpAS|_0eE0AyL^;48U-#T5ZqGNyg-OW3X zsdYGbum@QDH%J^k1%n2!3-r8PazEzZcbz`Bhd?aDGZL1m<9af}p6(Ow(O--6ufgxK!B9z}=!TG!WWjtGZg^Im(_t_*Op{T%<5np^c2E%HdpUXBgo zr-R`}Hex6T%lEaBHTm!AlH!o{N~0l9G0bPd#A4rZDU6Fi=V6}*C)p$=Cc4#SVo^lQ zVQjQ&&CZhm)wG5)LpH$noQ`gcV3GwK!tK9YvjI-2cwJ&et7zR|5ULS9%_aK!Czfxy z6KB8Ri%jZ0p5)0&KYdk$ZYRSSk4i7cq`F^Qx3`&l%J(q2jjH**g48l|LjxaQX`JD^ z<@~!Px`V6RwjDcWlBvZ6j1~4#p~sa`^U(NXf@QGgvX(-|=`4A^zU0N#XCh9;B{{No zaKXZ{(=3z5OuAKxE5T#FgtUVBH}BMl$+;7IO{l}4V{$>!`bod-56SMOdzp@B z)*fxPsjvlKQAie}TP}pleh@0gj z=x&KE>wD7|=^Lk1zO#mw*3MPwb-~&Ci&T83zT0w{Q4c7QC6U~!^Sck)=3X8C@K+aW z{-&6*VjGn9jiD2k>E5w0pjy`r9{RwcYCHDd?rE$03t16;54EDGMTk+Z5HGHdnUu27 z_H}E(x!lnEn%(jS;{kKf=3ah3uQ@NS8VuneZ#Y}3w|M-Rk0Im{iwNV>97Ax_W-Nc|(Dy&i6iqU>O_*iEag!Q>i8{wbFX z0s+}A#rJh1!-W^-Qa=kCb-1u7+)K^4C)#|BCJt!KyK<}~JfJ!_xzKJ0J*@cbF}%O} z`5|Hl8LST*mg4`BN_fo4JHVNVJ7#e$LDN^u!8p7N5k^T_>OaYUdpGA^!mlN)!^>E_ zp`mW)vPR4g_Ob{|05wJuQU~l`GV>iWIGQ6S;nVup2yuy1rwh`d^<)P8Zi^O5YexE1 zRGt4C?z2LslwcylF9sHSwX^4ZKhL0&s9V=B*l*6i*RqiIgL>UQBs1FWM0&iZj*{@E z==nz9;F1O>H9JSGD^gB?oheMHWhXY7Hj9!}M*Q?q&9%-VhbAt@(qjlOf_$6}te*~b z_S~?G*09?eXtmoTB#xg9VRg>)%B7mcFkS>uBDLajQ$8VDrjxRH=-5J!d zabr_}wNuF2bgI6Of`XsJ(c&YXTV9v})GR!f>0r>z#FZ(E_zGF{X9!TvRdm^on;<1c z95Y$QTw`S4Dy$UjEoE7$bU5m}7vaYuo2|L7@h~D@K&A#ObIY9UVB{Vj7sq7!yw#u? zfyfDCC_h1A#IN1YPvv)E=20isxbefp1+Pf>?gW3|fs3w;@N4(yG8Piz5Rot@@iLPFIakfQ#}CdWoKo>BqLMtlLq%988t%+L{BPX+}bP#lzn-@Az6pKM2` z-+u&eLo_ntlg3{L;(jFFh8URi8N8)n!12c;FpI*0rjdL|G&k3SgRO8dY$TeO@b{&~uL4m^&@t-GcJADfs$EKPpvH6f{{5 zmg}6FNf5mW8gvNy`HQaM`<;4D`ylicf1)E)Mu2fKy%A3a4wngsPzLzua2td3(W7L6 zaLAV?I$yFDE8%M&LZDt`69;Ks-t3GXePu~-S2U}8speNRl#}fRoRHg5$GwIcmIS-D z(6Y%t)VbH9+Kxh=&lyC+xUKDza0IPi*tjEi5rkUY7MQBY(mw%Aagv#om(F)$V|Qy> z#$mr>IHWz9mZ%^sGzh+D2w=k}a#eAxm|Rc=h@wJ_UKVAq$;KZEkK*|UA`nCOCB5L4 zvh>jx&kcwc;U1%SnPxK@d>tNK(xv=cNVNN!Oo_ZO;y9@@i5mK!UX+%8LNK(y9vhe$ z4a)Hs8Xgsp%uWyYwLaO3z#-f~k&}_BW`a3_xg5noQa~I5qy;A@_|b#2L{o{*s{NB~ zDP{$(z0;21v+57QC>=F1p>=f2hejk=&f~g|5AJln&U|P~)QvF{w;h+6uhz~H;4Qn! zY3E_zOL!-;wour%ENt;t=#I1XbF#1u<#B%#e(mp{9}hAhzWiEk2{R`OIzjzl$ifm9 zrh(LgoW}}i8O1;DWu9l`BeP-XJ5m)9^@KWzdB zz`t(>5*be*;_(~;DRghN&0!m{9>91qyjz$n4H(g#P&H`AO_bjA*SV)vBo_;{7Mjh6 z8eG+Xyfj7~)*1o9w&OhJ-#00P<;Mzf7q4pR5!ph!uE%z0;LF!6TS_q}kH^@dtb^pL z$X69QLoQc)p23`11PX#|(VWeiJ6E>F%MJ^f7l(v5mI|AVnzGi2;!VaV$iSu1mbura z&cQnZ=npD5L7{7&bQGZ&b+_y=TWCDEe_w7FhcvvRriA5P zfb2!^+OOT7r#6T)=Zu@6@QKJhqF?1zni}m?aeaVPG@-7xd?L05T?wt2SalJESrotz z3!@fi+$A3f2NTjZh7L_2q9BEo11I5)e{nLVqax)6!+%zDWQI*Og`S4?j?XmmF`@c;a&@01$)!zKEKsJSw1? z6~lkH)9g*$t_5GFYA%GF6O?46#@x5#JEx(+LwKv|sPI|Pu-3$3RE_|muiw~$`>;tiIAGcU?fR8A@Lz&Xdz)C{b>6szzZTgGR!v3-iEwgC3Opf9$2mL$%0s75BRM>a|NO4rh>BT~iw zhjs%h3KRoeFiqyw?7a(c52x}xt$VMgE-&i@&rhR#XizT=zoOI@P!8x!0Q=vp9Yi(M zJJcGX^W1C3`Sa&(aoi%ZkTktgKHq>_!h21MYveb+Tb$?YUzRv*4_!LGg@(aXtBtdk z1gjKFjGd}fWOJmmiW7jaa8n1AQ2s-oxuMr5SuFGo6zkyf`>QE7zn(A$p%?hIYbB*2 z=aSK$7;~d{1T*v^6j&l42T_23S(2`|0b+qbsBUQ|IqBl z*y;bpsKx?`GN6QK3?w6Og~;vs{$rlu?W+3U_VwOH2cwAmV<7&Kd-FvYT#NHD^%7_% zi#{_L))%V>5DjXQ6+7(Lvqj7G6Q(gk2#22+g6$Irv=a5~lkHdh;r}P*hH4&?FhRKn zZbF}4yG=e5m30Y_4iC~l36X_DlNv<2WU-@c5I*PqNm)O3I_`1GZeWr*pG5H`=E^~D zj{5@ls6nnz(U(jo4g3cxM$Ql{UtEoxXy__3e(^`HxO5l`pS)!hcI}!a4IP9- zb&?>MgVf$}SqRZdh=3r>6DyQzQ!Kgh9D&ig@Dy@$kB+Rh1z#1 zz7E9>b+`c1?q_Dw93xU7yIdQ4_Ze)pw6?D#bT5_ryNwePW{YUPk=d~7_gPAJgfC&U zoSNFIk#lA3Uy+axxzTBKY!dFSCmPq5ZfVL=4n1+Au?9sTU9G#C^MjP4mOrJ#42W_O z_B)^Nyw*Vj+;V96Lnp?20?-jzna(ks(e7M?m9PDz!WNegV^i7iYR{6PW{5*s z`W4eWpFMq8Gc?E@$d;wW)X++PF>k(%J2{R$6B6J*KAO_LP@ry=mL5WST z6YW$YmM+EgP!bmGr8rZi=-3$Ds&PdCli(GfzZ1_9dT>ebEk4B|#PZPS1&k#n-wN2{ zNT?&ya!2kYpEvB;itS|XU7nLB_KUW^gM)+VoFX|=Zed#T9AN6{QnYso`1 z6%>>hl6qk=Q@=>h6O{m|M{9;{5Hp@^15;KaZllF6EgV2Lz&+uQG(m2i{K}`T0%nZr3|zD(BrA7ooFdZx{O?-BNBZ}x#O8X@+5h~HzyB0~ zI2Tp#f4`=lW*Gvf`R|v8B>NedSO5EEr_Mzy_`hEzHUf($|Mw?|h@36>`;`9sGZaEl z`u_JP;NQ4*=l}16{O^hUKbFY<*2w>M&Hv7k|G%BAe4n850GNVUs=R~^r9MG}A&nNO z@QZ)m|r`J;YV%MG%C%#6{SNg^w{B|^r;O2yqQy=2}snQD7OEcM3p?1E4Pz5!E+( z4tX*JfuqmDoJ!P0G@OUozrV}F%Zmv_5i=E7>Mk|Ik%!2z6Ga}}YGC_sQ7tEIDCb@? z^`76g(ONWSWQyb$SII{=aa6V`o;F_G&0lfPe)-k%f_gwm)3}20$Tp94=H_pk^?A zY+)sx!@B)iO*uGdyjp=7CPN=gW=2pvC%eS>I4O>WP$<;A2A8WmfJbivpN?`+Smdo? zu>KvCT%9z?_d252QT`JaojdWzC(3391Dzp_zoxbUX~}4%k_La>+%)&tr0FM7GxITc z@LW8>CULcqm!PvL-1LVdIgc36!P@|RYN*KP*~V2%jtIk8h}Bjz7)+czhWqr zEwL9V9|YoM6zhlw*~F{GR@Y;XWnBPt5c93>D;UqhJ3StbYSQ}_-Xd1iaqN@=AO`q_ z04YBweU4KT;_ATr($dBk$>||I;UST0{0eZbr=Xf8^cRVX2I%1lPXy59uzUbOINSn< zCy89`8@oyqn9%nfL&^l04DzOsAs|1%ZJ_h2PtU<_Ly+XLPivBcw z^h=b?`a%6AYBV#A#h@9rxDsO06q8cgs=2>L06jGD4FirBK#0rFkzb(#oO zXohKUkBzM@3n`f7S_Et6!_iNzJIFYOH916 z_WflLQ{P?#w3x}qvPbguwVv@Na@+q6@PGHjA90Uf{*51dEz13x?EM4AhUOCn1*^ve zWs6DBg&~Sr53M0i6f>aPtGgbd*+mf%b`#={)!-z0?n;&Y&ny1!q4pOj;Cx9a`sW?=qdql?B z)s7_#&l=5ai8RFyHNWPgnK`EI)0YkPG8$-X?O2`yA@k&^M2m_mU^!V|+L~K~I0)GM;;lulsfQH5d1KvopO7@87Kad%# zjGco&v3Ktr3LS(bNwk`L?|jf)w12<>jh~GsaeH;@avG2Mz6OAs(Q<#_@*GH8`*j_k zTpU#T3`Q7vkbGnQ6r$8Jp1dO~odU4+W}oL}IBz~7myvw(H+ zC|2po30&ey{f+Gpz55+mft5IekUWdw&Gfy_#PyUnK5&iEg+NCbqLu!q zTMk|rw66Dpz)2}-5Rb;OHtvVp=1q;QV}BmvZFfZdw&FC1iQ{o|GYVn4) z^lvmEs52TzK?A}P7HJAP@Ln9w;p`Ujg#b6pio7L4s?sYErjjjh84{@XteLtP8fAgKjxOQrKa0l>wAGra04W)klbhMhjD;- z6lCgh4Nsb9_Uzeg%%j3@^&HZG(o_&3&&wQ@4Hz?p*j=I)volu1gE5)^!$&rCXE!o| zg@VoH$IEYdirvnOYfgmW0YEzCZ(X=>W^-~+u9qoS5x$cXcUP3|N>|r(b%b={xd2uK z4P{d408Sq_V3=beK3AltlE)?C8?*df-b~!HA?U{2=1hzga??DkV8FllW$78r7 zTRN87l?Xp_OWpclB_l__C#t9^mu}s-!2@y1XkM&ah~O49MYV6 zc@+(nd($2phojsw5ZD1-?I!!3dds26RJoiyX$B9r%yXkx`+LQd^>00z1W4aSCty$_1S}8&t-g0ldrwFU?)I4 zVX`B;Lsl=*wF+#N0yrFgFjA9in=m0B62d7A4rTHpOIL?F%C5SMMfA?vQpDiQTDSjlyfB&F^*rh0GWzVD^MKs^Y$gy*WqC~VDfbT z`#ZB9fAzKe$aV%VlZe>oG8>>wkABjnhhJoX+49n zh>|!Gx>6nm$A+%GF$-e9 z&m5)L5EYAoq}S1EeF?sQCa{5(%A&kF(D@86jghnC_6Yi3(@C|c-DPOm^;I&ICzdQ( z!i&091mjOrA9W;ZvkY-g1)F&Gtl$(oqzOlb0HdOVs@f%O73- zCryjRQ28A`dIlOsRK=i6;9a`@St2sTRVD>2-o70ZC}Y~aPG@Gx;vS+S>tin_mNkQY zWM}@`k~OpU56COxp!T>%w_NYdhTBB zHwU5EHBl4=^Q1Iegu5eF(+SFW7gB7Xzh6mXGd1_^()~6;^KejXvoT2jiFdh1btmEv1-9<%)f&{Vf5cWd z=m>!yX7C)VIe3!Bk^{I{n7Y~jB^09i;A&8_6HCGbuA8k7$>l?tL8yM1K+NWfE-qIC zwGIKE!FqI}Qd$QT|3o9?3hODl81n=kOSXaGAh!!m@WkP{8+;@bd0L7o$Bfsn=VYSw zs)B!}_Ko6Pdz)gp=-0D5Z_unfpNKell8qq-qo_!~Ygp5?+Y%KjYSQR^Z`NO_!1*4A zGsv-F5!&rQM;VKl>VJGo!^S=JO38bj0gjM0m$buRS~FjE=5iF}M;0mJSDNLI}NiiqHx_IHg z>jORqukbRY2(TE`Pt{r4gQE|7tuQc)3$n+mCuwp#2wupAm~wK2!Kzg~i@=h6jefgD zev;jwG}|9#vv1GR(u21`3#1YVflD)P=k4rL!)G(@-#Z|7JM7V;N9BR7^$kMyoLQ;v zoGvMx>theI*X={_1GTS=95MlMN|*Kxp^KBc)b+IXDXk&=DsZLbzU)U~sibxu)rkQf ztGkRF8j7#AX|YxvmtNVmOmw`zvj9emgR*{<9?c*Wu*>(yCF>rIyg9RHC!lx3aBe-E zW(c6(N?Gq{(RbM1%SHv6i+vc!zn!*dsazSjkJEETqLy!9Ze$BIxK9p}^Z*w-+my2?YIjJNg#n zG$_<~@N$o#Eg{$AR*dFl;y}(*2H%Hs%NDf{R-??_pD6hEl8?nQ$9I{cQ=vRhfCm&U zO-c-un4eKYk$kqXc)`CfFD*qn<$%ML+7a5ic=QPLF^%b{-I6;;Az@JHxY?`(p{p9r zw@efd57?nNy;3@m9}sGNP;U|=b$rvSCD}=-h5oj=JUBuWD%`U|;=OjJ$rrtByQJO8 z^=Iv~UaMzWS$ZWg9$P!|05*y_ct`06LWFHROT9m}P5Pwa@TR12*PO?neIAR%d|h)B?ylhf#NYg`Amg6>sS zYxOe7opT2k;klld?-`a=XwMnJ)!hHo;%ZW@JRas0(1pMCl%c3NeL8*F{{7QRHDHc_ z-Gvpoq9OiVx>!tVsjGd@wd(*EXAgXRwX^QtVFdg|=S*v=!Jh2haj9>6&_t?5{o2z> z`^}Y-`MRO}-k+mZdiW9&j!?NXfw(75j*1--TD~4EG+8ZRJl%z5-!nFbnv}*4FeGDH znrWU|F`m_N6FoC?xZ`$q?*{B7%0NkYKwh` z)M3SZ06#QcE2>|&zN=u#J!89r2aS4;DPF zKuJn85*;uJCxaZVt*rwqEniG!j@eIws+*e`2C91DOI>~1SkT&Y^c(x!y{xC* z5qq8kg8kpepMh9L6frCY@BRIgH&7WNWRujMWMbKY_(U%p(-_6|Bn`J`xavIF72Y!k?!%20UAh##9#hvhvkgAWn+x9HgENi(iszxN z4j20tEY}Paog?mQH!iOPY(-QCXr#+6qw9q!2oYH(j14jp3^c4)<0@k@5SFYW*V{q7 zlLs(zME&SD2X~Ddr5P>7_K@<@Y&W>3!}nE+&z!3`{HCXL?9TU{m+FDKr@DC2#B4N` zjBwxX4NLm=IYB#vI_{rr7AsejBDl%e835e9?`J$fCxaSbHN|A?HAN13&0aIB-jVXW zHyMjEaqAJHjXN9WB{E8y9TT4StwYxdwHSf2%n6n`L%F*0aO^aOf5fA6nfKv(aoL(H z6Ma(YR|sv=MBLl^#aCBvpe%rCGdr$NSZ>qedUz>9P_??luEZyI9%sLdHKu%YafPi4 zrXvbRutBsmv;?|fz20(Tt+{|IpafwM2xiA}8xR8q$s{L3TNu653wNxXDaT(+BbA{v z@*HZm=a7KSom6nS9;Jva!ah-jw;*3(%aRWKHRMW285Jiu0_qftSWJD zfw4jEG^E#&Pv#-p5E@)9Plvtj9)Ey>OZkMb9n6Rjbw~uW$=&U8IA>+Mq-N$YI{E4)cO*fqeTg7>JuHg`FsoV-q zNPUbJyMqIIV}K;Um#Kk=98r|v!A4$*GU{PWz~{oymS1kMeb{>vX`NtxIQ?~UN4u91 zWE#9mK?Lo1D757~*rYTlEalKQ>$xX+y5KQZ1I97kP`que&G36NAbOm4FM>Yf`RZcx`Y(b z=Kzqw3$9EQJ@bo)?R2GumkEO2nF*tZwJ0*VBNZ!!#` znnr(j&Pd9{9Fbs5!GzmOnGvQK&;IN1p$@g2QhV+PP*xTE1Rtt;LSx}tNnB|C;^j*| z0t{;{?wGAfMidqzgJ&Wqo*(-& z;LsDEJb5x3XefEz`sBJp(qUU-Yg zhwAFh*#XErEQGW9{Qv}g;tUV|idt_$+8j45%LhH$M>hki4@LjZi<-8@6`K%JwUq}a zob$%FFmfQu2t3pWN9` z&qT+>%&W;%Y?|fk>r2=ujd5vl-Pv;%yyG8RAwz``MVfK2*l$FWPJOI%y4%5zQX@l} z*SI)2m)^naq)^9sL*&T6o*0@v1=Fe^t&9%EYCFtkpz$*Z>I#GQm@@ew_N57Ul;1nR zm;T_v+@D>K{t1Pqo);!w`X=Zb5rPMawE_}hQXWSj0FzQJQ2dhFK0;RKB@hDqFCiE@ zMF3`2d*?#`YlDBIOAq8aO)vE?&HM>S18*1fqta z5<(>2$A=U5c_nrMH3YE_bOQp^4#Q7u1vb7bt&GBd0bon{77;1TLfo$!OvsG>Y<%6NYz;5~*hJe?GKlNnsI&+%It)T&L0e6a9J0z#J_)}y_CQ~Nja0oV`_+YJr zQb>wo(4B{mvg%gsC6=H>o#L?Z=qUr%WfdU8R7!V1Y#Ks`4fLMO8 z`XU`es^qsdL^)=PC}Es5ZOWc&>BN!c${hqbUmW~MeATK@-6eDMv99=3J{Xx@CT{|c zfn6Y<01VCh)+X#n4;^P48&(A3!V+{NRzT5!AoyKizy!4Z6;!_USwFfp2!~lU_5p@D za_i{mz!c8}LV}CF?ym@ekRSu7@)Wk z@h+4ffhm;P`4JoGBjd{~l4unT6DS4n!q|cL!T_;y_IXOEH^NXvh5$54bsJrgTjz`d z=%83~X^u$FGD0eK{C%N{r2-gVMrx$rnwm^U0rEF+Hn`vxMD}GoumRK+`Tz^*ynzEE zVZ&_N_WyjD#-9V8D(bHA&yV@{|9>r-sR~)kf4+(S`S$p!|1ZR;1lHL}g3`mVn|P@4 z4WdP<$5j|9p#i(NCjQDSZuo8*iZT5AghFqtH}y*%8qj22Jlnsl^~ic9d<| zY!L?hh%n$%xQy*g=Y(o&0Q6uOOUrtN4ZMB_E@vjIt6QI)hoTZrpoNM7qN|r+T)_zE z5tS-y+w2aOZ*5{>!QgTw6_o|1wziFb|K^X`Z+v8ae$MHd-?riJuewTh1^9P#IN=jG z`>+QOLIs1vGzcdUum&3Nhp)mli~qitx7>j%{#C4fJx?xk=&A3h;s?}@N}~NFtqujl zVj`p{nKYpRd5gFSz`>GW;Z&Y*5E3UM;wSeXxp2b9l{h$Q;Bd%b9;U(|cY=a+nSb6|+fUVhmyD4RPKVX-|6;G)(hyn;T<#~1In?{)aGq6!urux@nM zII~I@D4v%EPl>XYzC=VSgVEMjcKdApqkq@l28Sm*13J4S>)>C^25wvvO9?5dJL!V7J#u8>4qK&Q zTd{g1d2ZhmIUa+HT>pKs;s|+zXfM31`1DkE=5Mjzj&p%(G1RPPk1--r2V)r34}R2) ziZ@r=`Sd92@pw4t954}HekUjW4I*zx@BZ9#oIWzXyoIQMBH0tfhS29trMv>XFF+HZ z>}LBTdrfm+)v5-6&~}8mV$|6wT%{+Wf#Tsr877@DyLY!HtThuERb)E@ zII;ks*yoN8ZUi_XK!qlwgR2thXf0W@)Zv%DIZXDAz zz6*%SyCjANyrZ^CIwFmx`PT&C;t0z8M)Lqx-^e3e>+Bj0Or{7?v+*{C_Quu%#}gqWjd*H4Y!Wu{9Qt3cA~*-yEmRTC>fvF9%q;r_HwwR!)5H3%j*Nn&+Cyd zCt>gc*k4qWv;M4rvI9^l+4m@%<*(|;53kkyacvv6E0SIq23GFv9gopozC`LAO{W4I zHq}#1cD8*_<;M4hzmEOQf`($0?grY40LW9V8Z|ua85YB=S*2%|bEi)d=3;6jP)Z_a zDK^u%0!2L`RS0enNBu8yY3odon>zmwV?`-_x{nsm$T*A>L zw1rpei>Rv(N5No7lMBYTf?warrxz5-EidV6$Fp?o6DveWaP`k|8$}KS(>d!y5w^fc zN(*e}n%36~uM{yzc<6M8LQslW9>Jx1L44AI1e008$B$B!hk-Nrefc7^Bcy-@riFz^2Icbdm(1~&~! z1(S;AZz^s^a$2L;!yb=2avGGIExd~q{eolO1km~z-ZpOZ>cJYZD-j9D6X%r*)s`3y z-o&Dvf}5M9lh}5ePJIxdQXV6YYm&1gtj8pE$C5`sKz7m{3ParXgCE|3_Kt&~87aG> z_^sj44zd{F^>LS2tYJQ=C8DSMsJmzOik_9LOd7j%l4whf^YTeHjZK?<_L{t3p8(%} z2|KgSu-7N2vj(&QeXxJe&RI~=A7szoWa{{TUxCZ5m$nl{<1Oi5Z#ayNwpd^HbY-iS z+Ms;SR2~oI`LGl`#KhH(p`3y?FBGig#{pm&j(y9nLg-mlrP0)(jZcV68|fmU)I#A_ zqx2sPxS^a5BB*+HhXa8Pmb1alLkKTfdC|^LQVvjPE`}!YfQmH0pQMq})Cz$iq}je! zdH&(NQ)Zti?!usA>;jrOhtmd8EG)tjfKWc4J522q+{KXB{}nI{42T=1!^*c9XVK@c zUoZEJo`X(><(#qn1(Xd)b82A1L*4B%6Azvw9+j-B!QhVPD^P|MFYsl-E%ym)cIs$>3$I&}SF?vuaH=r3qf z)eD+j9B`~XXTm$aWU@@*(64PXoA?Y4oS!PNG}lo3vaf-G0?t7o#wb>sM#Wc8#XprR zR$u%5%R6RlETx93b3eWToeyb~514HmJftIa$?^1KnS)I(cLFi)jQ4NuMLH;3>WrmA z%n{r195uB0MM0V6AL6ojoIgDUdx4f8U$p&j&Dx;<9r|+3ncSO@Ep@uzBd`jYyb!Q3 z>4s17KD02oBqDlQ#-35Cf3 zW_@r#ix>>aG-Q3$Wsu_(^r#^K5`ayYdiwH|Wq(Y=x}1%>jOy(zSax7J|P(H7bBy z2MQ(m=nd46JO!ZDXxW+c)10*?`E4~i z>dzjFRo&QYvNcQY&9h}1N{5-5n~a*4hraq@L(8F6;!>Zx!@knBADkRdy}KH6)>K|N z@``AcLH4xZu&DBW{89C+tZ7rlmilYzZBbAtz1E&1ytR=(h>`24wjP>rzkTzXTvo%V zr7#q0gXc4DE=Xu{kku%ys~Vo3bPT6K@=ss)ugW3QHs`GKyX`i9!EN*Xw$gLHvbA+p zQ_HriE>K*rd!*8M>IGx`7QmBewgp%cO4xOD`K#6c+YLeXXbdJ3*H43b zhwzZf-5ZBLrD73SMYF@3P53l{GnR)qdgJd>7Ax zjpfVAEPuc9YU;b#o|DH)=9=dA=b%rIv*i1cmguCg-htc>WhJT7^s4%{h+viC43E}Z zDJZqB2c&-XQz<6#WI{wiUt7ls7`2?V^>36_^Z)qplp=1xzllzOE?I#<$V%irktX-@ zQgipG8YrKmel#6}rk{xe+Nsnt2QTm|lE^_ooP@#%duTJLaPveW@Syt{Q2`u?VoMj)OIAwCh_Fn zET09&ff|BAO0le{rEFU!)v?N6o4qMM(s zdaP69Gi+S@WL?g>dKR1>HH@5mKM?ib$YQ&j`r?5fbMm^qWaSfKQ;PiRke3+uf0~f?z2P8#PA=!-I~1iLRr`zcR8i` z&uPeYe81-}SZa?Q!FeQ?FagTnLe!`K*e^ref|OVc(Hz`PEQpSu{6IfIie$(YX%^gH z3=&K;Me9=CANbw#Eb7xggLxDCmRonn#pfOl&%Z5jPN{k+&eOv^jnf$juFMB&4_i0Q zku}dZ#YV%C0p7h3B=bD3$?kPnB#8MG0PupMu6^`6$V(=2MaRi=uHCEKcbv)klqfp}yL&-Fj- zi!&5`>W7X`OjqmH{@KiNNLh3ANzdBKUp1`e!f)Rv7AbRS%5pah8IF6)UT6Lci=PPK zxbcGhSuR^@kH4DFGI6ceYXL=VUv#97OK4^d3u&y_E?3=^p3*HrxaVtC z^JsyjhxhfG{=Ay;{+xZuQBeLj=gDv6KZSu6Of!&FzzSu!Mt?7*ybPL_8fA|IH5et> z2X+^r29)avy@UCeTLS$2)t;{*f7uOygBW|7{eo_KtS4ZC;sMFFLW46<>~Pk!F3>5x zq$UK9CA5v;`dERBZ+tth!S=?m?_%VF4KVqnjB zVUIgD1ltxTd2&tYAPi1Kqmw z2R1i%EeLczu zX^LCr``m5NY??wsi&mVFK*VE5#qD-=)c~C!1W76S=WS`bRjqS(+)T6p#Y!|BI>|kO0>naqzWwe#{6fOInI!EEikX)sr_T_{y@~dg z20zlQKRde{A4gD%q$E1#s%f7*amiKkJ(xurtqqxat?~y+%tpqP5gN!ged*!Kr=M_d z6Fv~dvxFgD3>9~1bo6w988hZ(!b}Efk_e`9pkF+BK8P>)ud7q{Nalx)R}Xf6w+j9E zAxok;%b4ZAIcT;Hn?wy`n8emA@T1<4!=AhM`Z-tk393 zoi$>1Xsdey+g>NTd4B>Nw>XL<9i>+`J|oI#e#&P}&BOand>TB`YG1l`@wTO(GhAck z^H;=`NI6DhV(8!Ull>rxY5;qKXHtDRr*}WfHRuFI0>3wZhA2^;2Ce9V(E$jq=(Nna zaXAePh8o}{s54stO;E43V5`D7>lTx^NogBmyHWHFVayjLeNd+dWr9*sGzOlEzyiGr zFwa|%ZZuFFToN^qR@t8NP98wQ=qICm3`X$QfddB)VM-2Vwh@+t@`lC0qWK*1J!*p} zBNje}K@GO(4xw2u7PPD$gqsHKYKx{+qG{kOKFB56V}eIP*MPt zE;QE4+fBrl+nV;s1it#j&LumVINi#`#I7l`tjXS?zPtv;Y+Fw*$G1z*m zbHPA2ujG~d@88)Phjru=%psGPIQ*}skF=&+4xM~X{t ziB4|bCb{jzMC`e!gxzbYFxq(DOu=BGo4Zfmmmr; z4DXiY&agV`LvKN`cQj9~MGM{56u{XJka-5vi6{(-G}F_@-u|Y>otuz@Q%izz0pCA9 z;>XIP^lOOmPNRRdYA8o%19i+0qEjEZw62DRgC>b#hp#>*u1XBaw8Im>+W~Q)d?iyQ zP(>_AWo!dMu-laGhNV64U02zc$Dc!W&5z3kh;uy5^hm6ESfkbnXj&Z+TvSB>5h0-x)$(J=Q(MY@eT2f? zm}V>`XZP#Zt#V3J;~r*as@>o1c$4t6{W7@^jcW@9gHw3lP5*w_`TB|-`%Kq(3aB_S zHzooGziK}VbX65tRx>^&|FM5yYGXk3B?<3{a`fD2ceh>FT&%Xma%s`mI?kNOVly`> z9{3ce*v;$pt5?^X?RBb?+1&$X#!f4yU1|LEd*j$~UHynjH&36Ou3auOeM0t7MvGY1(4J%)ak}l1(dks`=N-zgbuFbl;v%Z=>#~?bBz^ z3xKV|7L<(1AH~zIzP8!UCXIYuxzny1o-vRu$}~*Bn|WMkf{>+uQ{Rc- zFEplziGNAiSc?*trW|y3Zc{A2RMyV?Y~+yc9ESoF|5YHyN-fWA@oEG6TJBCh6^Sp; z8V%nTxd1tH+}-=F6{JzVHcOTOKW=OY*czP>2OvJYo&jw)t>eXD}L{@ z73$k2nxov0_wg)mSnb2s4&|bdaHjdKq@pRy^Tv$kA#DNq(TN*r&PAHLIPJVRrimpz zHafUudT&d6yyBNhl2^tvH<{>OJ$vI(OXK1ZQ4M7e&Y3;i(qc1{IVNm;4oiB!xym-l zIYH;6d|%^8pS|<_!I;Mti*MniKf5JBv#tESyX-n8eaG{>_dj0QDSyC!msI10eJH{C zHQyM_s_4FYfhrdbPmj_mK4Ip6N$_<>=#-I5#-I!1jF@QVhnd zg$)10=soonWf_j>uDMy)$ERfe?WYr)wfWOMLbL85FM*M~BJMBskk^{4LFeN5#6Ww{ zmh(-2E_<1$iuDVp2Q!IZH>-zloz%*0*#FU>+l2Ru7~8_M3r1%r-)FAKs6IbwTyRg) z?m?#9*;|D(9VOacl*#9m{|tV1Ec52ZSx@A2&yDa^wOxvwnRirk*6h)|S8BGt%Fhg) zy1%)%Hzi3}%&5+oE>~TZ@l=s*)9~#f2@a#GDe~)m1zugdrJg%w%$94bpxxxGseuC( zwclSa&o5zfwOoXi%}gnq(ApK)9uoIs!Io0jpR0TK>Uqnmm8FPie7aUNq$g?T;$iFR z&aUD3+h_D>E_d&U_Ewo8nZlsAbrs95%;&i6a^$|bBU?$e{Q9tAyO`+-LMd;9&U9?4 zX_r&HXIwb+dA^ZGV{#AQrx2ciheda$h$UQiTbyL)YGjTh6Xow2mZ8%2CDz-+-#m1- zh`f-atULtK|5{H0Bm4Zj%j&v30?vO@t~4E0e-^6nW%eOOpN9QLcKLhDm?`|9M!yz4 zK0I(e{NrV@x9XF71%PRC?-5a!JCzfkZfv_taNo;m;ZcQ?nW^Tzm)Xq~43cX1Dv1@Z zj@VG>T6-yy$Mte-jG)1Mp&xo>9x7=v%T^B-$-P>4Y?X=RVc$fFOSaFFCyTylkXP&G z4b7R;cBEZCbMR=w9)lRAIs#*aGN$yvkWW7^r_LigVW zr8aRV?iFDa%pOaeRz2IPqX) z_fMGM{Uhg$Gki9G`EHu!ta-?vGx99Qfj?}Lz8vbAK8rff?qz#+jM=uYE@06xOL3>} zl>9FXkA5$kmj6RPa4+I9C5_`5h0LPP-x@~eV^(f@Ep|n0DVC>Hl0~rGdjH$$Sq-@g zRr>{3t`60GBq#ZEA}UmDnH~4^-nm#+n=IFbzW5x=V#MQ{X1sbT+BR>YU~YEsIg5zt zw~tv#56?2S(;FC>WS7>%c&vQh&Neer*Yoewt{eKbFC^mN?-ss$S`YR1+?-dE{6Qr6 zW=KJ?(&Q_1)w6~dtxLUk{B2o@Z_m8-VWFC{#3zayql=3VO$0CjZ(?43ah=JdCoN*U z^3%u37N6YBPAR24_O%?pvqkNrY)WG8`)gI_AD>jZsd~rB-Mm%N*o4F7=#yaSYu}gj zjqc94wzICm=VjNhXPfPjFAEmLuGdP>=<&*HGAf%eX-s2N@t2gFt2Te;Mo8S?mSQH4bF(LD;4ZWnaYtNqCSmz>#$8c zJo1yake-IbQ)dH}%+82enWsia8D`g^nWy`o6>Xi+yPWW`!{S!tD#dv}+(-WC3EZ(t z8WdXPu`FC+LQ1vxxxhbxY)aDXXWDj7ntCQEG@|dQRfbAIX5X`*YkPyDHnK;gJ+#w# zT3;>{RCZjY?5t|whP15i4=7Bfm-HrpZ6gLlw zN;=-F<{jCx+vZO7+ZvCq^7Q1?l0zEL1&@EPnaUx(gDJRYSn0ewa^V578}j-6HIjc6xAd_<@9)h9-M60jOP}m27P;RQpl&cH zzp_n4vdbgm_vr7p--=W&R z_svL4MP~?0JU}4w!ln2+^;;8*E-q4DTkFqeHfM{Tx}Dy5#Y)K*!|fW@{o8*U(pFnsSf16^=vTvHg{mO?AD87to$Nt%P%cT<;Eq)3|Oi%nH&~Es$8qi&nnUw zRXWt<*&*8!Rr59~lzR!wC|LTK50`#GkHd-24>-9O&pq{x?$?}%<`^?CVcR|JON z*d?DC=P0=0=5O;vP0swgTk2c3x(Jobu?-DbX-Ry+TG?g}`2}VR#qoZG<@6$7n`P^Z-^Z)J{2ijwr*J`{H^YC$#T6YYm0lHzfhz^gt2QaBiNF4)S>weVd-B<2cG~f8 z@0=&q-_x<)S-gn9eA{GsJ8Of=tGD8x_-9wn9o`lALY94IyZuQ==Q6Fw7BTmnm`F|Y z`2EW!1qI~t7aY*dkIjgz9(0p<{UPM_YL4&29N)%c>vMGyHnG=D_j_*Ko0HCUgukoL zLZs}V_+Y>jwZ_9H4Ncr`H=i2hZOm#N4^gmJ4z|z{WOy3pPv15}i_x%R6N$R88ZYPX znW;3@>fm{g;Z4V8Am)N8@@H~?a9(x|`s5ZU^-H&5M)#6fdG4Wpoxn%WEeCJw9zEY| z>9s^zo`0)`l)gexbdRNsvp;{yQ?}oi_to86J|H?C|2S(kf1CD@-f+38xWYezQCs{H zEyIiXmAfvi{qjx2Ytap}(GG4!w=`B%*H+_}MLsvidWuh!u9>@e<-8Kxi?6?b{zr~w zfVigYGRqaxr#5A9RtC)bS>{mI`o1;rr=(nZR#(L%d!gPFtQ zI&wFb{v5r_%v>9%KB}bt^VC?3(ap0*I6nNc@(=q^xFOmt@qpVAPIKXh<9P-p>3<^k zrRjM&O=I)O4QVLfoTD*)_jl&RvOT&NuXW$gHLF{1|C&ctrM37IYwqrlm`FQgo`koy zeqk{gI|I%CR0u`xPh)H!NwZDbbV%@P{prx$ez!>LwZY$pj1@Pg2CefbI})?3g8zw{ zPwIV*_6aL*@{e<|~ntiR0mqD4pEXb;Cs ze*yfyMls*2cS_G$bVIH+`{u;J?eL3AVggIgN$>_U;VVN_c8j~7e*I5jkx+=>&c^or zQ_pPq6~SgPi9O)WBfF*Ze555a-kf7?lSy!xyTJR9NMy``SF<@7b4B%^Zasc+@Aha4 zn+XniJ-m6V^=&NchoH^Z$#QZJ?TYVnNo~BR`u31Bx8RLPBZtv|Q_W%#y6T;Ae>#?3 zR;~?+A3iS=duq(_oazTfr4a0n6*U(3pB&Kh;Jk5RuU36<_%0K7qmzMED>|~7hVYV9 zCJ-4%D_`=eTSOQ-^?SA*xq7)y9Sj(>0tJ)nAg8N14r zOR5E88LZydW+l~=3U_2Y>Z-7ubnIow*gV}$GAA$S$dx@6U~s~@9K1h2;(@8ou(%L=ITJ}0pus=wRwbeN7+TCLPlALYbd11SvBqguvu=;hH z^pYxc^J6wf*1vc&t8Z;`$+S&pQXgJ;o|mzb^Xzd>CAPEnuXi2$C#zK~jAht6@3?@W zEZ<@_-NAkN?{{CcIx|%wNGvgM&X%a0yaYwIywfg?wfDEH2ye<*o}rrOGs1q=I!IQQ z{r6Pe(n&V9*3G}7E((+sii$f3+~IOa6O-&-tH&3X|v zV?!-JBhWAVJBL}b{HK-TNTSQua`}`(D7J|N6T0^vagw2 zy>Y>Z4vn4fkLL!d49;fqg$E^FpOqi3d*J@VI=J_FT4bk_zPlao?_V=^OnE-Uyb3bR9Rdv zjXMffGe)1RhqN}H%$4=@_VfSQS1Mwo=dm`aqtHM9y7G+ECWeN(acWVE|QZu8I$M{!QO`9vGyZ&qnlkd8D%}s~c(VdeI zXNnxM(Pnhd-NohKc(eEto04AbXZy0OnG4uFUNbkAO-OlKJ-E_8^>)k^9#avvzSL^h zbt?a=a6|i2HcEZ%ZoNCpJcUh}Jz|_EIaoDf&`(f$$By6HC*d@eX~9L_vXq zM>eM&g$}0Gm2@sFzNP*3U8et}k)(^?Sp|{d%~3_d*<-dxeQSC3l#j3UYO&Cr{K?g+ zKPXMUE%X^~-PZnf3-lXWAKRN(2{|n}zVI45Bk<PfXPCr1e=0y18 zET!|cUi*(Y#V_sE|I-uDz_ZqR=c1}lIwI#9s;(_+?x~P>;I9~PpEEl7T=`}FF1fB> z?lENpTR7wf|I{87V)uCSJ51Lst8+p8x;n0WSN&7D?GE*`U)&6~Kb5n7eDIR6z&3{Q zJ2gk1Q^Is-JW`V)imqRJpPYSv3XaGf{%5L`29BM-uUN6}XfjVt;EV)IpG}}t%`5EDl)!@-e0#>iJUDKor&-oa^s4-{ozdJIpHR`d*5kg# z9l_-}BFUTbh8|1`dTATjkh&>j9oKWAg7=M0&4**7Dz(!aHr2d69vsf7ur8-Au)X(w z5+hgr>8QewK-I3M#CwDX&kuEKGcwo|Z*QuU=g+*&JjBo>;=yKB(O}a` zGstku@raZX>RYq4T@OmqlZ%T?iZ(qV3Z^N$bf@joCng!#oIII+%7l{9W!;6 z9Z284-8--?868kCx~IZXL>CY>wAQI_a{EC;S$t&q6=sy1)*#w%nVh(1E|R?(p-nh?VHbmf=IZnc3mj)^=DA77%}Ladn7ug z=QdFQGYEogO|CI3OV6x}>a+74oNzI{`i4gRM2*RW6J}fR{bmJs0mt?D%#F?315=l9 z0*_@k*`ZmF5Puy+-=|5^KyDrDXp*VUqvn8YLF`S@vS%?8q z(xRH>;71U}kqN)y`8Rn~WF>2YgdW0P^K0~{*w-#ElDw{^hc+7)+D*yW@M^|fi(;Ze zYbT3_&#QI~oc4HeHTqc0Ivzc-qt`|te>(OC8~g)_u#*qFAr$AQFqNt_6KDOye2?sAFN_{i1KNm!CrUC_&)!ur?&IkTAzvj&`}ccOF3^S8%{>%8XIB5Xp}JAdCi9r!?8+sz zH%_(VXSYju@eJ?papgvGVMQg3QLc$g`^wwXkFDz z#rrsW;2aOLxja$+EZ=*Y=M$saNe6YFbLXF`r~IZG9Fl(gf^)z2O<~#A4kLmd*^ibp z)aYIuKS)qe-c=`k9$M3ac`kW#CFyQOV7jat;wLdB@j19lU2g0e;mDG@@mnk1m#u65 ziPX2>+K3Xh(P4@ie~lb0pGeA6XXC~*ubGsw)!Q31aTZq%V%5S%u(!R4nAfh-L<*iG z`#o1z863n>$`?F{mmxmV*)QMLb$LVTu9A+Q_n%65APZ@7KEGQ$3bJ6?dY&oGyubA% zZ@nKpa8u#Jh8~Y~S632El$zwbl z&My1psAz0k>v(5IS~jN)lPM@yW;br>%h3PWeh?bBWWaF}AtO%WIrME3(uQ{UpfdYm zP3h6&Y5_;Kvea?BFvFAYNi9mEy@I@(qiO#{C^W962Q$0_w(8Hh)Cg`L5S)U1;R+C7 z0#X}5yMp==@X-*{h+k>?sCerZ3-~8I12VbT*jkF`Y758_O}}%Bhd@qom43?%c)g3w1M6yqh z%fa;q`uS^_`|Zb^JNzc9MGE)l@+yP<#b4LuL|QXb&bi5Ye||BnO*o8^MZ7-8hp1|kJ{S8vZ1nbZ&uwU@mF7*2)VG6q7{=iE=>L6WAO;Kh1-U%nwICS znLWPuk+=wElU99g|MT4k?4)sHcLy8wisTvFmB-hnF4cLo6(u;~ZhlPN{95C((k4$_ zNQ9`<8{rRcO)-cIBLMeSO6A9>_1n7;@UK@8z=U!3|8AK^xi zp?sw~@+6Cjr#M#oEJ_Ro1gR;me<^6LY=Gw(RW9bt(#9u#rJzfW~psUhw5d z_qk*@*w=bLdGaI;B<5eMozAED?A^H(cO$s;(<8nskj4lvr4>{keURz{9$3tv?&fX$ z%|eVXAH9Dvzz?+bmcW@K!)@8NE(VG+Dn$ojh~DzAO&uwk$+WuA%dHyZ^DnWg#HkYc zsE6r7R`7C9E_`5nZm_D)SaHUEopGY+aYJsmvf0N@E`qzdH!T#|OyXQJ#cGJk>dTUY z+HjaX^1RG7U~X*DZV*Lel4=RcUfOoL$@Jk4mQuleK#o{0H#>jo{h-~_>w@7FJkden z$8PJ<=hslqJYZCoNxZHmQ`n#OZQvn+Rx-gx*L54yCwpnX`bl_ldID^onXnJ?3^3y< zQi^^1+IwiN^Tzu58nH+1+rl#DbNQ#^g_~}fQ_)!N#U3YGn}@DqThc4Hy^%Pilsh9C z6165)lk=qTw%m8dq#POoWR@AQY)p6LvyVA^N zhjdm2DZdPRHOO$$;2n@&KT%EawcG_upFZ#`!N!vG1ha1ru$=TCGW9&gTmKVVHAUhn zXzf5x`EB7hbR19`S|^nY2G59Jd)(p&iSU>iXKlzq^kci80Ku!aDIWMh!9|ndP?3e# zX5Tf?0V6-Ao)yP$L$zVy;K-Lz(nAvYEU;1MF1Twgfd>Ksa)MCUf`AQw#?!a}CFBF^ zs$2?@kjOnZS`RX*YY*x8LH=keD-UY|-5#Ik1HRHb7 z?o6*7UH0*e)#bL$UvDVNJcw@wHyd9*BV+2(=Ny|9JEjuEQGEOoTE^OK-r8HC_cND- zM4KXK8a>F2^`)x$;x+T*`*VYW8qDoC)uX4fqpl~i7-1olcL=@-TvLnH(LLPf#Lb)u zubNC@9`9PE#+`Qc8gP9ac(82%Yl*tXCx!gfvUgJ-i}7k1QcPnOjuxP*r;_1OQ!cv!>CFJ-^b#hfpnes?G4+9zdg-Lr~E$zLdVdI+s>jNxIPiGk8kCmkAR3>@V zVo_*3hz9$tqpY)KHF0I4eJBvJ-|a8?V1L4%Ws{_qE>efLVH9sKYk9azR&d3+t!tc^9EZd)`J?1faYt&_-#;ah zu`16-6&)%^VRO_tzvj6ZH0a+r=u1}Oz=NyBgMI8pE!r@qccV-OD%DzS?)xJC*Sv7A zAxP)Il125-orF)y#800{OnF6*b}7cBsg}hGMl&^&9lP&F=&5dOK38=pDT$A|{AJ1R z1a^EhK#;NpsX;$FnY)~*{Cl?#>0UoF>GD!h>l5~+$Zvhd)T0u};laI`kQ!8D!|9+_ zpinrY%_d8dX~KSsQ|ryPI^#ZIPfc&BXYFoEib)svy+3GtTDyH3spy^&tQ6v1U?lmHsKbU2$JB8@+Ym^J_kK_<9vqakip51eq{nLuy#=0_(y4)MB@{Mpp@N8>eY@<&hH z(oc4#FRmxzB)@AWS6PU4&7chr_izt)x=IVy*ycL73A_j%gwI-wrL?nSr`7pV1bi_P zpvQ4*c)ZNq^jmw)z}8Awb+)N(Q3A&5ie{Pzrm%BEm)*N-8cG;mMn3`T0&A(bbFrKy zCZ@Bw+k8UHrv?|c;>-Gj8v`3Z{<<%oD&az^ZF|M=6TL3=-ftpZ`o|os@!BAE%LrA$ z)Z&Ln;=BgA%C;F#m?uzA&cv`X19=>Jtg9ee zImw-82(C;(mVNWJMjz*S0Fz8+UmDQ(DmC+Ui|aRt(9~8?DE*PFqdqD?`kR8O{X?4U zbEVAf6Q7G@jzC*c`4? zThRwLD|^2)loQ?MoY-#73`NSv`{aLak2AX|_j5zG(mYeNkf$v8=!bMmn*{tc^L!^> zL?+smtaIAMG(}^*a5&uWmvcVC6f_t_bE{@l+v!f~wf!^LCEvV$4eeJ3*OC;+u3TWC z?sxGTC@p+wdcrZ`xk7Lgd2NIPLR8`TNpV{4Y@MBu=v5n3qOb9RCGbmtc8%$9V`yFk z3p;kF>$lzj+5l^QP`zHFB$D{mEm#ei1En+CSZG!-nuv+EAn4S*u&@9fiCoRxty3y6 z>rlZ!#7kx`L_>_GexMdZMkRfV(bu?i@IaYoy6gelU{IgGaqg+&bM}}VD*l9>oNbZp zX3=8e+e{3L7GKW18*Ue^nVcW_hJLrk;XbUVkPgWEgu|!D+o(}0*}mLXT9lG)6hgk! ztUYCV32?w4{ic}v7vu{mv_f$b)1_?kN_jd4`Gh>a-V}79=a)DW^tv@os+O|s+e6%g zL%fGOO{NDCh_;e4b5jy*pJ0W`Ui(ZuXVJOsDrw6@D}Ks^ZCR&78trzUfR;cuN2)nD zx#S^IZQ{a!Z$)gJ*^7MbiH()b4FVpL#5+OW<^?440~+lkG}dQasEj}O^vyN~9#1!&s79z_W$HZAD9 z8jVPsXqUt6%Zxx@{uZZ^+f_Q1l@PPC=}WlV;cRib#PbX=SoZJxU5~LpJjYgZwhdpB zh~6%L@J#;cCb@XefsNx@f-X-^p4CcgikB=e-qC11^k7gUYh9F?a2CSe_Qh|>Rm6~b zGqYbNlV`EGeWlyE-2dd>GY+-FfFC0oNPMb)scf+`<8#dEznw6Zfya#mA|*~Hv%^%e|M4> zW!HjrpL^+yPS<&G)xDZYZu6k!<;2Yk&s&=m?|`uY2r7vRtF~Q0kje^buGx|m$QQCs za?YblsBxbS%RZ$y@Ou1?_YWeK4wkiO*5*gSh0@j{{U#j%6+hfDI?He}R;kVt`vy#a$5N!!_6lyOn#x_OPr!c&!#!)Y=sl8()|AnSeSr zH@+#YtHes-UB5VY)qquM&n%Y=S=39u%hIjCig%dj1&klqdtVSTx`6Hkk~Cg3#WKH+ zyk`)1P^bJFA1PZpmb=4!Xmt6^9UR>s{ZThD3(*s{s^#lcgdeQ%&Rliz!ow}cLed5149Ly7LioJ_D z_AOT}A5vRMmvu`&4)^H3Jlfpb>nzLGQ;lpzf3Ox;AK%H}+p|@FU~Pu`hK^|FshNNM zGI0OQb_q7ZP9liq;eEf_p8`Cub6qQ~6&FntKBnu(JnN&{fFQM04;qxhe z=2BM;iAS3Ukd&DIBK)NX(wvq&*lu$>_d8aEJ{%ui?KM%(Ea`Hz0|6?$U$gM&&}3)tL8fJ@$Yz{V~C?Y`79eP97?Z;dWYfYu9*FFydOlA^9Id%j^SfT18|EbU#$xdGy_%07SMyFe3$T9qt)$*ndnCY-c|01TJW7 zJUND@z^_)cFQzVT}fog6SpEFcOUQU|4?gRHcaahF~ED{`Ym>k>R zQmmab*9+2_STso8^{Zed4rd6+BEmy*j}C_^ZeL;^>^nfR$9WtVma@_a=STF*)bYw0 z{9R{`i89A`*sa96UgFDtvt;;4g3_HS_ts zyIj9Uy{HsO*y!?=$GF~HX!Uz+HX85pCbb+7JM1Py&=Et|Sct$FH|{+e(vS}m8gn9d z$n&!pxt<9q)_BU+hVr{2#HEMHrZmtl$n1X)=KIpfs5wq3D_+c^gjtFL#2&Z4<6=;R9? zMAMZXomK+%@ORa&D`z3S;Th~qs6oX_bv1$xO|v);4#!|y%DpWDB34`k;;&jPyaBxZ znR#ob^q^|t`)<=>!$ECmzi;1yKR14`3Z+fO&C z$9+ExjCDzghtWn%lj71kKk~umPoqF&ydToW<_u9c8O^ed62t+CkT|?+h~U#ytUsR+ zEg=GpV?uRgQd$Gi%rG~CJOGSq)Vq0wbaN*nh=%S2Y-08CcH?}B=sD}+uz_}fgRtsc zdbnb?{ho5{s@nu7*`SB)Z0kfqV`Qp{--okdF?Wg$u5{>NRs%(sHZj$Y(M{-2@y4?&6ESRW) zyCiS_uG(fxYWc2K1Xm8(PjBuLx4n*(6L)79NL5l;J$!Cn$(h4NQ+}_fgH}tMKt`9J zW9M01M!LfhE+X@~$M4ZvPWFhI$F0k73y6GVbx~yM5>P5<`$g%OMdV+aD3hX}$Dk0N zZJP&2aL=YwtOD)@pWB{^o)dSa;wE#%YP%O6r6;bip-i-9A7ujB0ZD zm~XJ{+siE|^baT>j-AUow<@^(2Y~+j4B;$RGo+q`qSXn` zrbH1ZmP&_-2pF2t!FknP46pvzCvg;RL6g%!SvOVMzA;K;(&$AC?s@HCDlH*E?UqNA ziW}`))+Dam;&*uxvnV~_#3dXs8R{YYio!KxJy|Zxz{U9wjq$Q$(@Qd%@Nr> z7`N#ZDewJNvr#7bEoB?U<}Gb{qZ6`apG|~?=ohsZxe`JI$ZB?)BGSU++rIjXPsna( za=wy^Y&JQ5%ybqDsdeR*>DS(x*EYFHRETX7td}-?bOpB~?0YJ{JW;Ok4i z@19alI%}C6!FR2~ZE5$0vaU+a^p8qvHl}rXLJ#k?P40FXqCA6HOL7k!&p(D!yTMPM z;L?P*X+}N^3wsX}XXv*%Kr-Gz9*CafGcW|~tGOZeDIb#dKx?J8j`IxlK-*jzA_e#j zltATgK*oXY?S6hSau5welL-7R?ZIoHxfcF+ammzjOT(S2H6;Vk^?^{VbAcaC;6UvQ z)FjrQ{JJ=E@!(>&Dplm4O~mvypwY;t>QQLCt~>RyE#V~h(~I-><&*VvwY7r{A64x# z^VjpMM90^8QzqR&?&|XUkL5mPnRf9Z>IV1VN5+` z3>V?k3k_=ga7j(I?;f#JD|0@V(4N9`C2!x}=n zB_$=lYkrC$m}P|tkxAY>4{y(8UjM^~CZ?WxLWJ+iJeJbw)o9V$S4{HvbJz(DL~#+a zpL1L^ODHo2}kZ$j(f9B zh9b}SUz1Ymkqmd{$6xN!0MZj~<{4-SO*hwa(}yi#X?L%K2%ZJFJJRVly`zUy7Bi)1 zcP@u9PWmP_29n9!oiWg?`1(CORU^gt)Q2rg)pA!VATaR9h-@=Z@==}1m|;x ze5+oEb_9|aGDK|~6EsI$=WY4N+u*RyB_B$BSXy*9n4+i{wB=dgxwGl}_kXD(~AEpr#F>axg8A#eKOcIdmk z2V6YH_C?yHL01o>N7s*|B;9H6N8XfE-7!Fje2NRh_-8n&G9{|$1DFNZ^u3k8iE9;} z#oG|GDAKAOFc8#D7*Omw%eXy{KkC_xbU(V z(gqk18x-2Lv+M&mewab);C`BB)#I~bIok$T5>t7cj`cswI-*wKxs$>b?nwe3$IX zrsdIFUmf`Lib#>DnQ?`m`4>S7VbW}@z(`R}4pdWg(0f6QnO^ldqj2Byl*nwNzGfVV zn?pF_Q8ZKA`$C>05EXYsHXW^k&hK9IF0h3K75uOXvB%BH9u+N$r#}-k;+uDb^;Ydg zR=u9rgsuI(uEi|IR}mWUFb3S-J%8`^l|qd3^CP)J;|Wa)YCn}!|L<%KWP2;#z|}zz zninMgHHKOas(mw|U2;N-m2z#>&*1{$jI>CzRJ8~9y13H@*+j~^9V?Gc_b*&a#$`Q71?FGFg98VZgVa)w?J zt$R8YR|AcI=8=pkzB`ud1gA;(6~3U)L8UfEK;XebI)91NmAzB`5OoT8MDe0euC~w( z=WEx!r0>kH<>}J!aQ@R9HY~`vIETykD5bJs%1cG8tW4+hZo(yFrER zU*1Kz4V*%7>3&nRa57ZHCb-JQZQ@KA`|aI>;KZs4Ws#dOyZQa9ShQ|Ey*yuxSlQ6Xm*KJ zM5`|#-0$QK%&d2D(e@F30^=_I2s3Az_dMry4~yOAWI?8`B`5u}oq{i!61Sj^WgcaJ zD!ZasQNhirxf8p05XfGL`^K@|wJC#0PPIoZZDpN2tG@XetP^l~S8#xfP9l$If{(r3 zjc>kdlsLM6Q>`Z+Gb7|795PD-bMBvW@sm!y+7pbgh?R)I_v=T>M0roY6&;sL>Sm?B zwQdBz|L!97mfVaox|xLSSpug;H(rCMH>A_6OHK#e<|r}6P7^Eda}#v!%>BkDC`lGZ z(Q4QFXEW?`$1JZiKc94Z$COJ8zl@0qsK?CsMgkLn4a>Q>TO6%NP9NU~0!i%A?`f@X z8ggHZ%TH~3%mm|U@d%x}h-B1gx2wlhyvlbWncR`l3`8971z9?=-B#8t8`h&;^T0J| z^?h7p$Au_AU56%`#c|X~ox#XIEl!?%V+;pID=Y-aB=Uv*y)U&cIG^}~R03A%gf0Rs znxi==^30itN^mn{=iIb&&12z^RjA0OP6qV`y-&j6HQ%?~bs1`@yzE@(+lvKVLw{Uk z6Ab5fwbog~2ylcQZ1}+Kf{mzLHUhR~g^{7Ax%F+&oo;8vIk9=v`M^L-4e7fL0(^Bc>2;uK^M+~wr%>o5nA}#T~K2;F83}`5Bj(e`j zlHo#}u%zdW_=P&}eb-T`PmPefI6toB>3`Lt!?FAFj2y?c2?G>EvOvEAe&0E@bY)wCbO;E@&^+0VqW2Vif?0M#}Y$ zRc`g2mbZJSE!@&|+T~5?t5U7pdM-sXpOv$gdrUuj+^~>*ttIo@M^>f#fSbnMGZiDCr4o^j;UYob=q?Q$X6AYc9Mt)82{a`_yyv=lBX zEJ{&>MZa-IN(o@sm4$cZVSSosn3ORh?{WX7wOHQ1d6;*uN_(m8Z872-7ugKirmX{Y zpPbXA#KNp>ODATC$&@cp8Qn4Rq(BB2PVuD}uM8YVj(-Fb@zAOXpL38u)@xUhJuVIQ zbl{-LQ_={+_nFjJk)<{wH=@!6m+T`^PsBegv4>*xJZ6ny5VVs!Xos#&EWu1S_qgQi$(%l%Rjga1LlDRh8Q0r^y7rmljto#iiKsoZ(WJb&WN zw+FpdJ`ADHo~oOEOV(2rZOSU!ct>*kO8vZ+LHe}O`=4!4-|2s;ENuuC1U{SU@sLNXqg)Y)-U8@^PdfO z-{`p5m!WtPb-e)R%515tuT3z2An%~_uH(UQ&j9(!Zx~uEhlZpy)TvaL=MC)i6=o&8f4ST z)zU_E#*5n7fvTnBn|1&KX+C(cLJY+G zLopzIg=8!lGU_W=bk0yM&3bZ3mk$%y7?T8M*r$su-7UyHeA7;P;)E=Cwl}b|jyH8XX0UX?VYp5%7rz-`{#nl35et>I(hZI;I1c7b z*mnheb>1e7iONY0fCaUtd7J>Fs_QS}Sa#r50Mi!Z<0YTe)v))hZsm1m z=yz~YOcZ@^0fIauJXdRSy4xX+v<;kXqhG#!mYWNJ)d&?hu{ckvcoRV07+jh)uzEvH z;?A6z-LSN^vO@ffB5odW3ksGS7uSJ54k80r!9ZjU16oCsInqj}OWl^c4gK3k=Ooh$u*c6c zq(fw!?l6M-H)rbeu@^y$J%Y@ydADUZP_0##bxuFJrB6S7p&kz*V`9EN#px&2wV*%q zbP(a9ooT!3-Q|71y)1vG%lU2mC_dudWy|4JLXK(5yIyPA@UY!=f40QN!8`Dv5r=6JF?ej|&O5>)5MqAA9_WAA3Gd zR@&l>RB;E-bIs^lYFKe}KBQybIiuYUPMPnUIe0#@Y+HBx?eID@Y@GogS10gK6-cND zN7!cYV(*15j>5HT@&$uvI0#CCy_o8SElc*Z?;0bOM;Hmcn1~7L&<>WhUtKSj7wE6m z@^W@WF={qF`o+<2U6jnYo3p%ZKmV}Sxv4Q^H8wffu-SnyiF1;VHf4aK2kc07z*E4vlm-k@xcPh_sio2; z`BNjX-{v+Y(C{)ap_G%8Ln+v3atiDa13xVR^8j_-0@qP;Fv|znezr62NeZ?@GLXuFnp;j!TcHL;fL*@>XCQvy zcwz5L?oK~Pvo1;+!TSa83pcp;g0~GktGK}Tz(Y{)KL)R#vZ(2c`*id*x;92>A3BJj zXLNq>0NwJU@Kw|Ck$8&lkp~_86}XPMe9+o4+h)^Qizm$r;UJ5TVl}(PUNAx&diPp- z^;>GWPp7WGaO}ETjOo;fjp+Z_m+hKrnwcoE~6-U*98` zq=KGdcEX!P9b-yG(_-nfYdV};|kS2!1fl+*%P4m1rwucNRa?ufXzwp8HKD}*zg3S*+JkU z{{f08RRf&wiLjDm5j2*C_f_mg?q6sozUZ9(>XUUs{18pP{-DbLPI>94&CNA5dZx@Q^M99SZ1RAcZh+pA+IfdCGf zMyS&wY$Wm8?oBk(BLGrHVvj(cH`(lw=C0fwB;1XQpV!LOB$PgM_(=$q_p&_t%iuKu|o>87}?r z4@Yw~GNZte8C}o=L!;_VlhYJFBqaGmFKjf+f}qA`!0+O41P|X?bI6NCBWfWi!1$Q8 z{(2_B98jXj$*JDFBcS6ouJ(g%9-2!9Zc=!aO^4f~B$1pt`A^_mzhVdLe%TNQA5Xjp z3lm#_4a~*MYIRMO1$FYuwS zgAIjQ*2tK;XmMda`yHy!wnM{}*}}9ZyiXi6{(!S6{c~|aRGWIRqI{AQNdW7J?nxWT z*mI^W;i_x3Z;c&f2453a`_OC@T#NEJe@gR&1JSduJAwChZI&&2*Vnw0ZdVvykHLN3 zp}pumw$4ycy%by50#Xt*=b;|nexjI5{=Gx!qSCcF`f6S7LP?M<7(_E>A*_%L7;#U) z9!`cQ!PDW#{qIqexBn+C)=oc1L>qn;Wl=%r0ffwyRz)+guwMtq>>tp8VLpCT1eZe4 zCY+Qd!sH7^nKhFc@R<1(Bxk!Yqni*zPRJ*r(!O3j|or0>PHA-P8<=?IMOiw;*3TTZn6of zb(`pQx59*tO1xx~wfoEluI74F~XF@3bahDH%VQ5{do=mJ?gJXs0~AM#473PUqOz!A{`3IUCv z4KeI=7fQT)nc9zTKnMitUybhU2pGM{{;|d+MXhtdp6gG22m7|cnYoAU{K*1F*b$+A zWIf5x>siFb&q1KkS-*=Myp5HZ3qrUVV#EU8_*OQL?^OthQ$>4@Uj^!EHc`kgDE{jG zLpr^~g9zRJ^oTTxc-`61_*RoEydHIxd~fukD^9&oqbhmfNlh065~z*~dlZydNx%z}agmE(^l*9wiJU0xjq<0yt# zt)3@DyPX>nRYZz0$WkB&D*MDJ#$MEuf6e>5N)J$*8hnsoNJ_J32*g-t$KCb3lVn#J z6iXe&tHgNi|MY*bsMhXgP?#>pj{DDNAVa4bh;pIsi+MRefej4HPKH(xLoq*(TfGw$ z9E^iFJ8!Igd-x@%uhXEZEu-z_N1LpytTf2daT`o-B~ht8utqacFpApO5=2dSDa(^F z)K8hCXGToz7ReINzraSO^^Tmpnr5GHED#WSSQ4+F$5QmWJ14Io)!LY47xDD3LPtSc z9m3QMABTf6-aD{$EOlSM72FukmOf~Q`WQmF9LA1B!1y0Q;Hc5^_l7+BdA`@x>%K?3 z7vCDbY5iRn(YIQp|M#^>B$HpAR)r}KX}&{xZGWZ6D&KlR@?zqt_%ut4Zv0$QUx%bf zV?6Hz%)54q=n1XPSoO|Wl|V||pU4Mwjd!@-cmAgixWgQ|vrr82q)|3c4o;7F5rGeay&U!K-Qw6z0U(5g4W zG{j(F4PjLf2|MF6?$gz{_GgfqW?|?dJnwuSTq0<0?E{(ScE}1T{2?KAfO2rWvhezt z(Y(k^9=a0+EUXlV2~V-wi$`JFNa%(#^##`~b|>EO=&#BJOor#TUN z7kEEsuB)kCgIK1%*Ei3hsznFnay;Z*2r?eN-RE=RAG=t^F_s zP%I!A*{qQc7=F+`dieWnPmV1TrsAT zt`HS9x%C>Chu<@By2HEq31116?{lLJi$>!gg!{;KtTG6}O^MDJ*DC8)$|QdL9c>YY zcmD!{sI)DxBcP<K#UJYX^jC)X%G^VDB&RhF|?-UUf}wRI?bZpJlZC|HRMJwGjQdJ zp#r)r{Mr%!^Kc3CpQERH#=6`eQ zYGeXWi2PTA`o8w|s>7{UtTML%K{s&J>ik#-ek__0pY%y5UG&eEf}Ff2gj(QT2*+8i zHJMTSUB&ff?&&y2%OGitlpRr#AaC6)S5 zA=@n^I_=0R-HNKMoT*7|pOEVay*6A)m#Qc9%P!3KZq{&J;)=BeF>SZb*>{x5-{V9n zqF>S*Pd}_MZ%Y1`Bv;ltde9;M_nDi1u2Mm_FEa9=`##;rLW_)0j`J0mkDC0WeAd^1v~uZq>D#3^X*ePZxlfl60j zJNo5E=F1X7U~i|1HPSy!-%Cfsk-G$_y?zXio#p`{Bu2Kkq(3qoMd^ z<3vtDcD--;6vOs6V|3#ag5iZwt?`zaeDwH|>Y^O{R9OhK58t^osd_Sy-0AeTM{f3a zhhWaXki4~9`xr4diNY7QtbOF8ihhT}^Z?~}3Nn_h^0)ME=)5*`#cKcDK!#Hw_qqWe z;d@WwN2&eMNb5X{T2zN$@SKsM#Vcb1r26ccr)t{xg(v&JUwb~@b^5RIoEcpG4U)e= z_%0WF)`jl$3M^zpc|Q%Mpx&Htep~fgQ6Yu(IaY)*^PmL96-8MFs))dGsaH3DALJ8vP8}07rRM zL>)S;sz62jpZmcRbPMn|)(d0XQP$sCksQK7sN>VD!TsW1*MBpM8|rZcM@~LVv*#}+ z_+c5?@1roz_b$S=CCN-v=iVn$k^i?xi9mlQ-y+vB$7MezO38TiJxg@<+S4G)Jy$*Z z=|`XD29$R>YA1zPTL0Td^d+O;Q!-B8<^1;+X{QlKg1#PA|ElD}0@n#?K)veuxmh-Z z>KGz8`mBs>{;hRavwm*l&huD9&3D84lV>R<7sZ@2L{ zmwq-XX4D06G-c(#vnt`9`)?l^LWlYVS)3I4>xqWBctU{Y9BDx5I+H`slaTrktN0p$ zN)AS!F8}A5Mpwi3y&7Z~ELB}r0;u{BQnX!`zPPkgVunipH7wB%NLx{V457S>?%>!FVf*=IM@Im`S%uoo6IoiSg z`AW^{*Dsm~GrTP))O)t#zfTpfn(aFz^n9j>JcOu{pWT8z-$8lu9*)1Yh=r(P{)LT% zl^dIYl)zVXqy|3Lbpm2hLhUD5$bWYal2mo^=VvW7@Y=_Ma$Z-JBbV}n3|;5`{iYtx znK4$oZvP8wL49=pIc{L0hj1bg)ZP7+wD0)n0Fzp2lgNK-dW@SvK|hT%S#F7O6aPPp zy4C5^Sl0j{)}U;VJhefYIrf2FIsb17-{sMXnlZ~fb2e~h@22gS7Ooc!BJ6+O^2TK; zaz_b}5QFSxX5~&)&ff`7(cth!o5a7}v zNwq=w@9reuCD1v#Y?P>S$H<29uMb;Uj$^wlHMgPmw5tYXbV^FZ`}YcHBqOX*+*c}w zzVTCSD1z(Ze>E3ozMHgGM#NMb(sJhi{fN1EzbTj}nF!^gEm@VY5jbfrVaovevqtgP za*ZMXx;urV7M@oA_eIw!{(!~`3RXRE7gdcvD!}+* z=IH+tF1_OWrM@TnuB2H0cQbUNzZUSmF(>`}OaHswNb?V3@?|xT!aw}4erQM;tjl*^ z1hP5GE8e+-NiF&B-u5l2vpIRy=Rb)!yrO$p?aGRZmbd+Ti- zpY?zD#9CRM&A|1TdrALZA|?f9kUH0ucTvppQ86*t;a73OuRa#E{q!B6(+`kMHWQ&0 zBmLhxJKDE|Gwumx@AfweM~F~GH*g z2%`Vu3%%kyJ^pur|GavT6<&kho2U@9GopY7Oro#Y=~7X2JhvYiC z>fD*W&~Q)mePC2U)$^z* zR{~Jh+x=N50v~P8JBvhHea9 zuY_IA99VZdbvNkzzVM$a5KY|x$Zq9QEhGlSx=U+N^1H1XLrg9+A0PT`yDyivqp?}S zpS~z-@3QGww;AJNaFGA%0bLDuYu|Hp0Qvueht-j0)CK`_K9Wg9ga!oYXf|*(2vX79 zbI|HZfKz%fh*M=8V?2hyPankqoA;nJQ|r6UbW)*leh_^@BXpY#a>`3A+ey$acz=7P z733PPQl^T`Q~wbWrt#qNeHe9U3WrdH+Z7|4EBnjkJ-1o{@X*+2G-?Xu$y0E%z5|r7 z#I}b4Jwfh#F=f)j7GMLat2_xfEIhQK|4ZW}kA6%!UM>UBy|b;J@K_{G5KS#&YS>vFb-u!U}ufHqj=4Zs*AS5CVo1lj0l z0K1vjIv&_Oh$4QhN}x^=L9?8J2s{acFz5dBZCils{M_dLCQZ7rW%h-@QiOur@U4uB z3Na*=YyR-Af=qcZdPxITIYbZ$!u^7vIo3cZB}HLRunnlayEvZSkJJ)3=PIVX)65I{ zR~G-&97su=o}Fb86FV!AlYc%`_=9yh=%f|r$KLJXxgkKAfN$`N5rIIzwry%`BmB1<9+XkbL05!V%g^SCm zk{Fx?nx5fA{P)O>3xXv+I=De#Cwg+4<$FWGBq;(^=O6%Q5r{TGcg1MVE}9qyr>)b% ztA-o2RwqQx!YP+a{9hI^CK}2h*wM>C)SuYH`R|Ula{#0Bg9-?HN?uORakiTl7RGg6 zJU6l_kTiA5W9?_>Ja;#&y+FB)M%Dp^8|?suZuh1&SHod6zx7Ds_xHzmov%^BA3{o1Ng~0s?_MfNwqX)!~K({J#<+QpDHw6y;Iv|1WWtM1@RqTJ8~TX-lT z5RNgsIjgaBIc@7j$-9zqE;~oAJPsvI|CjdvW~6C25<&9a3+M>QYu_zwqYfq#rqzzj zaPj8){WAM*zlRi>bod28Frj*>^c(dk0r~bk#3Dk5!Bol3%s%tZSay(}Gl4AtqCttS zG!7#Cv+g`QcD=~yKb>3cOV$e)exRKweB6D727(T1?ZwwUXd?m)B4rGGCOVA; zyZ*L+(Sb(q5}`R!fR=-(aP3=7`M^5*TftJw#QDAkF+v7r1Ll;YXF{q>> zB{m3AJ_6e8LG@;nO1PTY0-r7#OOHm%0w|9gvqHhkuaD$x`l~hz7f`3#m0Q>m^ zD>IW3mLuOjm2&4l=9FFe^dPzi^XXF%6CvB9kEi4C%#5aaU_mDS(?#)<{sd>-+}330 z&N>x$U7wY}WXHu+1-O8KnP|^WRWBO$oE!3(5xcK`^y?`hVI?=g6WlaAzT$;d-TF`G z-KPGoe$h{2UbUziK8ZQQ9`0)-goHpY;VzMamK!oS@xIJWPs_aF@ufWj?-yEe3wRFE z6R*T^QVwGkc+u$yo?{vbcRpCNl>&||u%$&zN=kshh-^F&0V~j4Z4T%Lb3V`g^1NCn zZvMtw2A$VJLcU!R7aw)?eyy4JDxwsB|KC!pcOenZB=kULW5cs4uaPm(T@o!EdO{T> zfT^V6n3kS1{$HQ}4%$?pJ@)^>1$4Lr6ht3=0miW})9fZRUcG3sSK1(fAmswf!Cm zJ|HJ2M=23{xzHlF=`Ml{2g8y#u$aDLd6FdmC{R;vUIm{=t1qcE>@K4&92Z}K6cq@L zRqddJs?UcRCslBebGocfaV0cs$=IW(49+6phVz`!I~dT(I%r!cfh884v)SFH3Z!(` zo}Rp<#gBUoYX~iMFpfaOY>=l&(OP0+A~mlWHT?glV0J>(EW8113{dlfRtw|BX;>Sv zHoeVvSbKntD$FW`qyKg>%$6g9q2^9t$JJ8q?%KNz-TWaqB?(DMi`9y-ybg9JMKz1e zPQ814?Mm<48ac$zCX50Q!R;T zXl{mT`iNa0e8zIOVW>ET^j*aqcEoCd!7TL9p@S*hWCPBVudxl0yc}+02rzq;*Q~k- zBcm?{&)xjWZWuia6Al~JoxVIh!f!i445ykGEQgpTV_G4n;|P0hH@ynXJk2-wi%>eD zH6R3op1TW-j>0f5e|QJ_(D`wiO7yz|OM^>PufjNCIE{F9K%Lm{Du_im{$r% z9n5`s9TE@{f`-#rL9z$OFQ&4C+*7lmP;-0O2a-tgMQD5PQJ;!$`p zwQRX%Pv9j|@yh@V49(C$SmL9wTOtb6`I%PxM;5`@0ps#g7-T?DSE-7A$`&?*TDNJ&4TpV-r(eX{j-|A}n2REo&9OZ}6Tu6neAiBo` zKf)gp_9566;J??AISK=gt6wOH+%Y?l%FzF>5US|zXQyGAAfWwMuc438s`>ggLeM*$ z_4t1kI+fb~hr6h*`-^rlGO|JS>aYDjMYJ>kRaeCt5KwjJFAv?H9m?NkFvGup2qx7} z_CVFu^uNlqXW&`P{y#sX)Pb6{)XVU{#>;&WusNSROtZu-2q@J z-l%%JF44X2l?tdQf5g7luT;~MBUqk}}rUX73m8P5PSeITusL4%> zcj}CWv*YEzC8Vm5GzL>uVAd{?oYyDRa9AQ2*xan-ymW~KO7E*kHj3J9z!iZggJN7h z$cx{mrk0MrA0y2=6(UD+hH&8Z(Zh05;rL2p#xEnppZ@a`)S+&bk&+sUeGp#-S(w#m zElHMk2|1*kwJ4fGBNk4QB-mv!!mhkO;)Q2JE{5e6;3KU1^&hO7j4blihjYH11TV|R zX_*dDA2{~0L)Niza5!P)4?DCF!QcyVPB<<4_@NUs0(qkh6c_-CAm!OG1TBXRyZYWN zs<(54f3IsyRMol53jfl7I94$ME#2+&FA|{)`D0JGz4{R9gnN+M5fT%BMLQ}-N3cCd zr}u>lqXPFdK$|$91xf#;D~zlHM!GlD3`4U0l9hPAu+aR6Kv5D|#+Lz02^DSoiwFCp zXzVkr^nBwc8q3=OeN9or_5Top!8{}Wi47}U_B$bRkw0&~`1{d%!>z=1Nr+PfRZR_e zA_6L9bKw!G?7!W&!mut(JC7YQH~RO+yh!nAr!Sv*_OCn{3srSKr*`(NR8&IiUpgi* znE;B9zwdeWKXl?CFH`vhQOW-*=mG8(`ewpl^$)jXtVnFA@I6XKQ8fPF@$L^>gL!9w zz5%n?Kn{cD*d~mZ41{B@)RlG2 z{Fv~6F=mW!&Vq&{OzJrE2 z9>B~0`VBB*3mvx#3JIw#GVe7D)yUF(46qkU*^BI)pZoB`7UW|PSPkbr^zRli@JO_Le0{I!FwoP702fH%mUr2=azCcShv=9M$?QF-57FmKDLqwAmnt-F zQvB>(NN(G)Yu44Kr*I4>@$^ zlp2HMC)}_y{BS=l$W;51RKpBB1iwNlhFa_<9T}AErXSjt(UeGp7J+?Z*;?maM4mvQ z+-~}3O*MB#V~<|-Dcj@J_m6D8bC_c!tMLU)HHK=&2JpCf4t8+u^2|JOOS^&aXv$oEdc>V3j|V<7$*Rur); z=E{Vap~eT92S1%3{{8GRHpmOLi!JDg@x29)-P->i<4D%9AN(g+MLYnq-Sz9&E1?f= zHVb<;2y(b&SYF6S@S4M%|EpK87(*oJ?~!o9d5A3*5+HgS#q^}o*Kl|Nav%u?g2*=v zy>ZarI@CRc^Z91RzmwA$9OZ~XPXXCcZ|CQg1~z2k5brW`5$%+KfB^8(XTe?q0_Mg~5>2pLBNFX&Cx8%>tkO_L*YXL@0mH@9D3g8~maAowt^JP44o zYyB6vxle%2G}laL_EF#!LJ3g@-hEJDJ6zMHkVecX;L5>~4DSG!-gOSm=qze+Ht+jL z2DLCDs0K_-Sdb|!yp>O50y}hvsaq=@i~$dZYOr?{EZeX!^d6h{dOs1+J0ZvF->dCY zw;j4{?mhVeP1?iLseB`jGsAl)I7G??%3yRNhEZ%r~Fo?2Mrcnq&O4UE4GMhQCb zAHdRBRG&Al(JcoJqecY}{BcN{A8Y73rRnxx@0m?wR5TQ#B}qdgv|tKT+D0(8AxTFnHcnHG zTH%EAL^eUh9fd)&zK^ZE`vAo`4f{|XhX@M^B-nToVXgoU-3NSOlolYSC~%yVMKvmz zUt$#w6#Q$RH=F?0GZ1wQ;k6JF5>^7&8`)1Yv9O4Kv|s2GfE3ilItsUPc%t=qsTb2f z(EKhpO!O&rFb_J-cMRP6YQg^iEph~^YQnc;LpGpSQ1*J*(aXvkX+4!TnijibA{h4}lpRi7;Rc$cM^h*k9mUd9-MaO+Bj;%5~)T z)D43Jj#EqXE^W(w(jy0T*un<~{JfL4uWbQZvHisB}xbb77SZ|g%sV8wblPd(c2$KB&zdwM_uJVZLwyB~XAiU%z=%5)M+B&i*>2Sk%IL z(u%kIOJN1Z=rI2*Xafjfb-uT2x(HCq5#Ywr$CC*dMGvQa3+4M4YRjpt>`-q4;BATA zu2JYE72+M`=HkL=aR3^{TufdIq2>L3vIZi&i$+`BmNLigeL)EUV*+ zl`kUVl;U;hFvXT`E&6^~+>M<$r`8X^3jioYfMcUodzO%$DyH6QmKwlnUNZWw9wvanr~OzNfk8|r5(v83Ksk!F2eSQ zB^ohxRYCRSKOsP67o~uqd#-AHmmk|;T`6~|HDYHX1{jKduvcZQtgHgocc49j=ng2) zLWD42!0nSKsHC740eKC~oYNu{*R54vyA zff7S)#CeOZlvwsdOhRzuVG@^Ng(ds{hJl6zysBAKEGs-PP*gn5$ixI^6Lw*b_M6b; zLXTizSi>;%3gGW+4}GclNdKJjdFujp#9?oGK?@~4P-p~!btlom-lAt~rUnOU6FkUFAr8*-rtQ(cd`9KYmWW<{3@{r*xa)Y80?-Seyg zrvu|r>&HW9^pC7|Kz&&1xabSr4n;+eJJi|gfdHc>1I-k2vm$eM$YD`q4t)>m+hcXT zV5g1Y&r*QwB~*0ZAHq)}q^Gy*E-*7k9NgGBxzq24-H)>2Sg3HKz@{60%Vxi zH<=%$r1Rc{4B3ve-zoU|Fpn%Rf{jMlg$J4<2+i^{=QjQ>=HS2!nAvFFzUaXgDF(G= zz@WYeQjy#`s5oF%k+x^w?JDws^cam2ftGXl%a^8i6JEZ2=KOo&tdI~bz-X)+zt>Oy z>%YGS*Lwpb%y3y_1@g3yhMqem%C*6v%^%u~qX^#sn)x!s+#yKN5Cf!v!VDJ^E3bYA z6dH1MTDgt46GU`!CT%sq{O$e!0->fXCfAAU0K%iY8O2@&$CSA+CwkE6*iqJfv4z^H zm5*_N=5HO_%nc4LX!{TzJ0_DX6TI|8aFK$V+iFPcZa^IR2y{-vap(u6Qe?=)@3eFV zQiiASHIx#qYtvl}wFL~b<^e5(7c7Y+C?6q&Rs*_;EaV`)>bmC!mV}%DtfG_=nTf2HrvtBqhib9s#uj9y{i^ z_$P^<{eR-6-q@qE5E!9$+XYYn&6gn%Ki-m|#)g{os4MZ?cI~E&&VGTi$|nC`bcT-; z+&Ll45!`ql1`Hs`Il(FI1Y(1?!0*W!9zeu`fSW#_FOgV@{Ip0#Hs7?!39kR&$4Io( z<#W)98h$=n|K3H@R^GmcFEQ#4c1Y|e`RL|YH$7O;cGq4v>gq|U!^tubkFBP77<+wM zO+SP&A36UbU4Ex@)II9k{U;K`@hxw9aY7zGdRd0@k6|^8crUEP1r8d!eUah?UxwER9I&9p{N_g*pD8|6oT`7DMY^ z$$(Q%MFIkF3UxGC0616}~GxR|D93=rHXDGqylNCf9t0D6YSN7a#wANnKv z2{By7bOH(>*#dmTz_rX_UJKRt%gf6}_OO9q*E*mZ`Nx$PkvRY2ETpjr+d~*WtTL)! zFO!)OV+IQczw$oUJs{0oew7{?&W08Hv%st#h;+t3X;p5M3I3^WTl*{LPlZ6?^v84@ zy$j$Iq5e(0)5sA&PROPupx=c6q4O{>=a9ct&!?HWXf{|5qTD0U2-Kf$Q%+3h@w$xj zr`pEC)Pp|)xF9?k8vpj^hlCe@6du}yG<^sn;F}el*sch8_{{(*vHd&uc644hEE^jOzl3i_E_G^_4(hyPIq?D1FU;2NdBh*G~??5niVck=2cmfm<@jB+|W!Y7${mW=#-wTC}0p&aWqh`yF+{p zv4+B%1wD58A$BCT@(=9Bym;XaZMsvKmQ)3X|D&;HbQv4MF9Ae?OzgS@?gz&JIYVN# z_Yd1@Z#93BO|a$vy>^2oDir9zT9O)s>sjW)3_oQXp5*E6977 z0rW9G5iDPAe{q)fN$`4SYL;wj+U+Tg!L3DJbgxPe=|DFMND8!hYoINy60?FG@`N_* zt{i2)>Vt32Ty0m{!U7~UNbe0sx@$U{8igB9(!KOc8)D%(YKiaj-Em)#>>Sx$}h>Qc;6W2&(+ifax$k zXp=?G$W3d(N69{T|nH0Gjaw26ZRK($YpXE0`!fp zXlYmmuV3v=)(LhnQB22rT9)wpQ-ruky~ z+!8E#z3ip`#q4*gAG=V$BU}`SHq@US=852Tj+{7;`l^~0A_t!>q0}OS!k3awJwjjz zP{c>?0kr)(3cFw?tjFXk)WPzv`19p$)2;y6b)JsL;hd`NT(8MrZmUk+JXRg_QxRm) zGP|b8LXoEX`BR&q?715@C?J_($=>|-dIa0#5~G&B;`cK^j{{QbV@}7rmQ{MsJulkq z%;6MTbN*(O=tv|~^I?>y^Ho%1LI0tSKkB{5pkG}=QncK;*iq^_C>?e7DEX^{{UgVT z`i(hyV}j??d-H|u1uB^G#e9Z@sWGe)@lX#_w8bHDYbSNG_FJR0<{)+=|Z!!Dgb73_Pb;GnU zPxI!#2UJpN9K;Gx;6M3x{|kb!0A*0$^*A~9YR#gn%*)Uk>%HT~=Ro*6nj-)3=qsNz z3HHVe#kaFR^ZCqk+uYdO)3Z8fBqO_-&(XUvp|w{^Nz-p#{%(g_UF8!`uUsBqt7H7 zW4Ajv8oI_+7#q3R4s?o5W)x;xNjlgjkoRfO%;NVaQXy<6_TDiP5TwN4TZAVqek^XE1$E2uVis`5Q z^!2Wk=oz~N`dtbl@y9U@oN1@Eip4yuRw4}7>@I33k>4Y~by+k=Esu-3I*$0yT?r)# zU{gLmHhEkArrpcnU!}U1@@9gq>#sPB^uEB|b-ySgaBW*3dADF3KB?n+FS`*ZB>F@2 z@?y4UAnl9ikxx8kcM{##*01}^SGspBdds9=v|(^$yUbPy-<-Nr`XIb+sg_`Ppy9mz z$tn4^lr1-}&cV$`XI{A_h?W=aTL0J&s_HLQfFVJD&gPiDKJCD+!9oH1rX}H=I5qPlk`ylMy zEjF9Jr%@-YEZ1Da6Pv}}oo!pUgwM^qoy)p1Qeh)zRA#k!^mg$#+*8K(uY7JRy?#cL z*GZJ)<7Y@;Tg^!du-&R=4~~Wn5h?qH2h2>>VUMqtzwjEu+FSWxHCSl_)XPBC7Xnt@r0FU@D5oP@p`x4tLMzab`uZT`Sah@TQAbPP-QBs zpS{LIvnDaqQ=WmZ`eChM6`pC0(Th(^X~dhBJ$)-5$=4=cbbZi{QK%#SfvbX|R7N+CcRdq%Qi8{SwMr%fHG@gBsPFtPvn(M4*JkoWq)t6?IlBgZc&nva z@zkA9quxuzlK8x-7PYN$vVpYG5%Et+Nj?2HS(p8DX+mcnchI>ML``_>+<0!K;(6i$ zck@W^i@}r5x{i3)KG(D{ls`CKXwCbZ{banh=2#i)uZ>en7VBnraOn19oHuUDwHkF$ zz0&mZuO#G2Z^y-Y=7rt+nT&#U0Y7l?$H7El1?m1tDK-VGVj=k5W~pj8GFgSB;35La zkP~$GKImfa0b2$kj#S37cxk8;diLP6&VRikM{ekGi*63#2rX6OP>Xs90ThcDTw9uQ z^xIJ>Zhcs$aa>N7BnkzoIuodPOhqK87!)U9k(0_SKxjsxzFfyW%nHq}>S8thB70pin{2AT|LF zIh}?NABPrTgG9TQ8wqgef0NgCvN`J^vcv7mM;lVLMj#krcd4PGc)#Y=TLFHD*%))3 zH$umm3}r@O7wlS>3^nJuW-rA2G$q0%g9lgu~E=U>Cv|$%CpySPtTi|yvI{( zS8GXjAbm>kG`ZkYLG{MxG*Z@LodI87GV?H}OJoGO3{S;*BBlhi@al2&$!}e`5t3qG z7<7l<=J)mdol?dt_sMT{d%BBu>S+4)YdA!+yl7{R&D-uX%@Qf5!6;sOLo*ay~EH%tR2BA>@%f4;M*vZvVC z7~Oyy&C@c>5$kxq^{Hs?ZU9;pO9w^V4+S1lSGWp|J)z~i$7!NfauHq9M^Qo=ex zg281I#d_wKVb5cg$6}bKWtYB%E3}jtKlrKJ zC=NA$^R4UK;f$UPGx%N7g{)0s<_$7Mtt%;4ABLNMuTAewCEAPoH2(AC`^@Lef3OfoM zjrA4zBb9WgLXym54?zG&KJn*opIlVA-2UN_Mqj!6Y1Co{&<;sNTUuJ4oFKym|KN^q zWc9JYo5f(DeA99GzG`>VR};EhIu7y}FOkzu#Zg8j-y|zr&SBUYa@>}isWw;YqcL7D z#jh?)G*Zd+$u9r+`A+@w?Am3;U#byO+70DT$L@ZMDR3p?B8k21ViRKO*IxU{uPbtv zI5nWN3{O`3aof@#(yG^0V?#84Ekxm3b3Zu#6{{geSo_*56j0~MbhpS%7rTV3tz68$@lx8 z{&YIU#e+y(>CE++hgY#5z0p~7Fww}S(~^GwL?&@L(~P0;eyyj&ELNgjOKE86=2ncgl}utUffT8jcMEGhQ!njjC%jYG1u1lq!k%;>yYECtY4lsKY-^%2XPR^6@&!!0Mmdj5& zpq}FbttGQN3qQN+%eN{&X<%U{ zcfLzhc#0BWfO$)+qN>iR1v8l;y2bKK+}ms7>I;+KbW^c_}yIW zJix~2KH8&+H`$g7AZ3$d>`cjGv#%X&JDqTu>vztL$n!BRrr37p>1=M59MfaFoy}5^ zVBpV+tLJb)zD%}!&Vye4IrBY_pKjlUdAf=g9&3>fdEaI@qp7xSlf2%r{h5_x);aq8 zrJ4_YkJ3&lw+&Yp+v!}3V&({UOist&dewaF5XDS z*R#D5!DB)9Q*GCMp2rri7UwPg5}Es4M75JM&t`3$%BFfW&38L-O-b~<$ouozmjplX z+lYt2_9e~{rM6ZZ+2zfA|M39lepsUUhawZ@k|QrwEuPdy(x4N>-1UZMEkxe;uXtWH zUkm-+Gr8Oro6vVszm7w5Jdg62#bjqOfWW}nn-imF=t&a}qx z7lHK+7Y@(c2XI< z9S@fV`wtc=PD^dri2c%q$Kd8M_uHV+euw;nb}&3nvJ*7kDyLGdcGs7&4d*i1Vwg_^ zCAF>?@qS@BmB07X!Odg$w?y-@MV;eKRpIKv%MyVr0E;3*Euxx8NJ=6?8^SP>ZI{kh z5c1P+V`=UQVTMd|4-$87G?Kgx<#VQe@#{A0!SXjtAdD@1Azk>UgER-jj{=Eoi}7gt zb(Pg9YDNf@YK#XPC8sqnb2o-1q#Vz<9J8erQLuUv_Zi39mQ>q?mPBKM)Luq!ra<=b zmqztpMrY||Cx5->%&9pjbR1m1)qcQGsTbYxuJnD2_`^mPSMM{dH+iw2d9jm89bfkI z_fV6%^9-kt$RCID_$gNZCubD#8V1ftz6p|OoO#1GUH@s2;oL)s7k1}v^_Inl$}%uJ zCY0_le~y|M*zl-GeWMgCbMB=|xd?;pMpMnZMjxZE@n$Q``(68rmTBJI>_V-HrtQf%7m`{r@HM-S4An)w>k@_gQ{w2=&vqmw@tRB-e&OA zNy|>jGUDr2>@DT}^weHl&VopJMLFzw+>`dQcJj#UYZ{xQvzPUfiyY?}XPKYi7XE6I zNDaH4@*^nh<`}iQWlTTu$&xdE>g;9uo}9CcmyLD4Hb;9D^_-^glFn;YnF+8M?r#w- z5Y@u_z}1xg(V36L?Ag1=w>Ytx&OBp$-hN|~Z(h&Mdi&_ul-b7Ca`?_r%>+kMyq>Vs zmdu@qc^-fK`(qEfiS0zXRt0Ep3k2a5lMY>PAWO41#o(ROoj2Mk$+$mJ-n+Bp)L(Z? zy!Q90NplZ{tmPQ`V>tX_ zZk6xZzq4hX*_;}r_ESDK3|G{4ZFIwFFLKR-beERgAj14&oXl)ZuaMoBYjzxMpXU*=hbL7h*xnxvvX%OA~))U@)g%5}TD3wN7xp)>-e= z)h}kAW&6@Gm$|t;*VXNwd0%)il{euTZ31!`1re5?Y(1}(GoHdv{5XA4FCmnB`oUb(k&w^yfpmNh6nBFwPs&El z=E@G`Q_7wIkcx_dv^n=G}aZ?}>{SFckOiHBIs3DDb!>spe*~ z9#!%k>gdchQ2oxbYqdaNcI%hc+$LqeM&-?g?)W?$MuP4p+*5-lZsCDnl2{m3UIb{6 zmYMQ3=E2I#tde00I3vDz?Q^4DYsmvCg}I8)gA+A7UA<)?u)eI>*Qi5M=&NG|AIM{+ zk6_;&;Ib@h*~POXU9?!+yWCWUCoV>pa4l_Ah2ZXsqF?U0hP#F>d_P@s*?(iQvsa$t zB;o6ewPU7;ue~a_r6G7uIy5xzjS!awT0Y(WsfeQo>)UgC?t7Mjx=|BPgq|E9KJoa) z!wcpMJL6FtRI!CW-bjXACdj38YFhAF^2H)qb`qX-*S<+kn9y!UzF)R|p;PZpXq*6E z&ma{S$I$1|#o_7~3|hDytgeInm?_P|6fl}O?KQ4>ERbZwkawjD974bK|`cj%{9V7hL1TfTz#>x~};lbUlb z!K*2q)wx3?aFv5n^wmz4U17MA;X&%VrL6DQ?>i!ck8*N zo9Q=$e)Sf?qW-EISOcSd!^0vmWQ;}3W>rO*y;suq&j@wR2t0N1(~96-Q`sE;G;@t> zgO#MWcwJry<%+HEINJU8Eu{N$XG-8O0&zTP_PF;BQ<=4$etq|hf z#zRVv7_@cVsv*VBGENLQJ|ETh!9llynQQanE(*}|_@CsqMS4GpEwXLY_~SEO$=bg% zSMhdA?I`$5Zd!G-!=`5M+>BH=;SiIVCOJu4SLTlaKF;KU%#bDur`^ zuih2Cc0LeG2}iCkdiCWDjk=Duj~bQ6SurxN+h<=}>A2G!3~Y7r)Kcqz)|h?M8Pk2A zQ460k>N>~tt8sE?@x|`25T}$5HtJ92EI(J;%t9R1qY^vcfOh#E_km@Bbhet|T1|(h z>%(m6!7;)`LZRNVJ8<0KlNrQvSvIa5I6W+iBqUe4VZ;=y>rIx=6!tvYz;=)*0Q-`+ zmc-?7$+JrEN@D33Avi4*B@QI=cB%0}Q*{#@FsxR_ZlH#emU-)eEUtLvpy%C+8-A=| zJrAyxlrQbv&vn5=4H_~vPfBttd036`J6mip~fNaTX za80$9os0GoPRN}tAICin)k-oJy#OWG-u2*x!1!?C{OOhR>;n0?eC>R1sMa>;M=ZOY z9x|Ekbk=rcWvXwS9zX3PEqt=4$=8kwmTT{4%xqRr&`JdFrDo~dvMfzCuB7P%zdFqP z9Io~b4+*vqZWB2L=5FzIq9JH2XxJ zumg?&YHq77{jBtqpW(wg4CZR&Rk9>PaV-z6@VOpLD5j6WMs$I+gh5RF%?gA{q2e8Z z(yG_`&H8SG^RBosqiFRbq$jT$`8D1xs@(9f{6J{>_yLWm%FOIhnI&t;Dr3x z7T!_WRdB$$qx0nx|JuPjS8j<%E}c^TT(|@l40$X)bZ-bNH)}^Zub93`4=TeX%41Eh z?277kno}y{F}g9%m3K9n@>Hjix|!Pd7kxE%u+(%Bs&zT(Vs8tCbOgCU5BIoqZwFO| zY3Pt?UDn$S)o4rBl%~VR_{PMl*C?W-d%2QTZ zIp1+YHncT)qt4*5=Df`){Km{B)H%etn%m1TJJK~A=UE|cYQbuk^usuh3{%T{x>0{r zXS~USEwsS6KVp(Gm91n3y<)Y#*{kdCzjsU;NPfXLyWZ^u%_qa_{b#2rejUNaExp4h z@QapKUcg_NZ$tlJe(z+(=-fk!7Q+7O(q3)7$RCZLJnS#ISZKX?6&<8;tPC5#)(R#Bm%f-b4pGRYvE)gK-CliZW%@o8r_ z+r{6=zUi8vn;iIN_O%HNz3r-L|%*>-_^7W?moSErs3s#~=9-@I2f3+?IP= z)Q_GNLtbzqZ{!OR2G89s!F-z?{-|GIEY+=w=r3hoUagmQs)s7r%B7Q6pC+!W-)5rI z`tU~RTtnk4=er>XH|OzM1KHN}UY7xV8d=H$ogJ0^AWs%&jr`=IbwBXV4K?vNoD}&0 zGjm>>4~P7p5^l(`Ccnpe137x`*N3Bs(2J`RNe9B_!ABZpdNrR1 z9>j?R1|K=ST-PBgUlqV~QPzrioKPY%aY=2HQG|QoeQ-YQC`YvCy8?= zci+4Gsu^VBhIs40vpdJCR$3c9VVl`urhHptZ1e#RI1c*VFEgL6CtQz?3cvi#H6`L= z>$$9>P~$w-Z8w}MN_611{7tfT`&)d$hARQaVw}H7HbI^;p!e6@h}@+NOa6$=Khs&DhV6Fhls_si=S7K=+>b8m|d{3r=d81ckbRK#E1 zGH#B3D5|zBbL4@Mky>pR8|;Zul{$;ce8vJUKJ5Xo_A=E?*KH;dZhEMv#;Kx`C9{v2 zs3*rth7>Q)$f%r}gwp3qQWas_H9a4W__PCD5B>w3m9?$S=@A8|#I2*p1aeAlbr%N4 zopi9iE)tZXkRF?y6a)pBQX@lJ)56AhYDb$`_OsPM7h%P`H+1nZHpVh~*M7(GtC9ebz$E%qE<-BaAQy9#`3 z(mXWkp41*Ct4CF@QSxf{c!iV5Fm`i))UNv+{aNJuxjTB4CA|kJA?u)Z9tIPm$1uJ@ zN_oq}vqcz=ZreAjmo|gy{#T4QyndZL% ziq`l#L;AYmyMFUI7m^yJkK>}mF-rWBwL|ZnxKkk*1m}j59+=pE7Ssv4eCtYs-f-x` zaQ`h~c0Y^L&w5LW%@=yZA3qySue!7de)n4=6@HeB#39L;9z_^8J^WX4h{@jj@&VD@*dt?=yPnL=Fz0A zAX2$2f2qrQas0wYwS9N_w}v}CZ&mF-1IOD5L>?e;H*o`9J0gx|rb?~X3t`FvK4uinsR~8f^TRlO`iR= zxL@;acVicklr0LU>%z1Jiy3b-@JJey3yh3FMSRNKSidjCy1wQ9Q0~_XE6@79DWy`` z*!AMs+RI=2gK#iUoqg^}j2_{wJC?~uEULGCs^^#4Zv2v<3@>%SOS!s;!uC^k@8~}k z@pH#TWSkcj6?OaF#5V=BcH~cli;w@MId^IjNBp=I$eq9)gPjoj%@RIO#uH*$XK^o0 z!xpgzEtKD~{&)s3ZVaumv4LY1(scnbX9;r25nyCvkkg5NXV#r5nC1P%+j|)%d0#`j zikFL)BngW(m`F_gM#s!;#_Lxqe!GD3ZVd1JSTyvQ+=?vNr6r?e zUBDLLcq1rdy_hV{BSukU!!Q~8RMZMeEX*4nD?z&>v&swZ%Ht1+^E{S^s5IMbNy?p; zj;`jk^%@-z5IExxukq%_#e3R~FoAOgomKI*cE#Ei@Z>>aAb&C@WSv9!1UykV!KWWgmVws43&I89-rBCK)~m5& zhxkxcsA^G4QE2n5tEe6j%f1ZCBH0<~Zx;VdPhFw}^ajg+| zH}~~NP{m+jT%=be>a^$r`;ATd6F0YhHrKb`HyG^*rl#a3I@?gNcIE1)D0h;<*E`2B zXM8A{d7Sb6RF0Jj#YM1IfJLk%g>;rq=}Dm3jxUa7hz;YVVVSg_juX4F?u0?pz>EXHQrFcCIQX!w7pvzDC;WcE76b~6SAKAvfqXS5b(@uZ?8c3SQJ^=Y2?CsE z-B-|oEVrPu#Brz?nz{_>(FSNak)VnzmKe#qK`EXcvIG7rXuhm~>$Nc=Ubzh8@dvRm z)8V?5+GD}BQrmNP7i|Hz7!cg2$LLoD;H9Nx`9L8PX*#$ZyWAahkb6(p$}`JPA+4LMzD`9Ig7RG~KPQ za$Sw>wzEGWQx_zYrePARTfGfa#Abe6#E9Wk;8UNd7GsLkhM^LPs=qGd$VY~aEKFXz zbu7YI%?xzQ=pEYV-BzeWo<88AB3f!FvNOkB23ouy@kAxhf}dwhoC;wtl>|TTX{sL5 zQJP>J#jsmff49hMfBU37GB2(4buv}N4+~GY5|t|l2Z?e_wW9cdm7El-+7e!z5*nMb znX~vk>cfCwPmGmb(@x8UB!%>l>2troq~RxHFxQm)7&vw{ePeODI=YtEJoRu`SX6|P zn~3X~DV4ALWCS3I1UOyPHyPz+G+bF6X1U6iuJwKK57CY>qx==Ul`3l_3O^w6v9R9* zN#0{bb3`sEXzQ7QW#}I^Y6^FGtJ7)-OTC-wv}%lp)wPAO{)??eo_YtT0lC7)&}{R4 zWl{>|H;&XD7QPHn?xBnG39enI!)nq2Jm07;HucLe3u zHu4TCcphmp&=JU=Csjq+SAXKp+rVB~p}2T;x<~ zYxt@moJvcWX{^NU&aiEufa9 z{DP3YtaXB9lWiA^0vF@h*R6E_m2Uf}_X!3uju|XWGmi)U@W8OqwRg(zc#H&Qq=nVB zHI*{vJz;lU_<0+Hnc42VWW9KU(H2a)kYd7NaUdA!amszxT*h&&en= z6gCnGXdS8(OW0C94IaTm0AFr@5%+Cu5?xfRAii6XML!V(1H+#>BWBoS zJ(1$$h3Zys)H_$@B1*?zw09zHp7%J`rP{TmSL^AtGu0c@buYf@%m}=i>)&o-VOB$? zf*}oIvuThzcQT8(cD1u@xi*mxgJFo@Jl=#Iml#zL_3*nd+i9!wd&HPG;hs#y*?jM{ z>)hD4IRWXSaa>tMr|p$@!_@8TIXZI=e6|(_kZNHtN0?(;jqxdzrAsCx7Sh4c!Kk6- z9F;PK2N19R`KtQYyQ|e-N&+3&aioSkE9t8bw}nJ-$magNMv_(X5g+p=f|avmdhA*J zWuCQiK{G*lB?ylo-ZB&xgR{T`}W9%jpCiSCT^fs z3Y4o+33?%Xzsm19o(3UNRRC#NCda0r<$dpP6Y}o2+P}|id;oub4TZ;YHy6xZrY@bt zhp~0xh`(O%F7Wwn2H386FM&q#9&AY`n8c2HY~XeEWpJ<}>77`P0ehycLS<{&TULbK zD?Rja^)vLS+fuwv#FgPuzRKOHr>l3rFR3l?4W|igrK@g4RSiQ#uP-Rw+0`pNkc3c) z8G-;G^o-7jtXEclH)Q$UnZc#)YEa5pE3Tfwbv?OjW_I(JuasI+|8iW8`IqnrCX5$U zazg!^5B}0(K*yh3>`VtGg8RN?mp>PJ z`kdMymza>2OJ%3e^Th?hDmj8f9AoQrtd@t?8H#n?4Zi{?91~Exk9gho>+a+)i@!A z51KFY4qdmX-ArtoyR3<+Z{Vo6T8(g7!M*~1CyoBwOb8{dxUd`ICMg05j@JF}qFo0t z6@yfQQB_mgImxES&RivbmhP0*YM6^YoYI?Pg9`dySdEw2tLGdFFV*fwZc-g%X24*i zoMiP4Mgxusij;(cAP7by$|s1>ffWY$MBND}GOWxe=iD)yyvI^z9djtN-xy!=8U;-h z-RGMb4XfA0^ji2D)qR-?PY3v?0TZwC{mC|~i65u5zkf2Zm!AWGk!X6X1w#_feQxr6 z^dt8>LgZV7>b1F2HWpk4FASDdt5c89pC-XvZ@N`Om0EEn`1OFxH6f$;V+lBFo*FsY zSQuohG9sXM*7#INym90+8@ogElCd0GdMoAveGk9PJ{MBibxSysWy5aVhdxUW-RKl} z!avA>jA3?W)}c%Ar9?(0tSpoVoR#=m@VoI!@t!v1nxO}K)U4S=7=)v@Ui=;v7Baq=K1=T~{WIMO%#}EiPYwg+$-}co-lb%lr55i}5^S;UL6@UrT5<3fSgyba5L0cohKBF-8#809ge@ZW&Ga z;AN8kX%~5T{{I$-bgbM>oMX;97>pwviy!%KfEh(nQqmPrt0IG}3C+8w0??=`401Bc{{*&{o#z&m1VD|4G)FMw!U@zJ z7)%n2L3^5#W}z_&6n);HCyp}FQ+pXqENdvB_Q%D-<`q^rj+M~b$Nn`bH6N)kef82i zQxrdsK(Lm8{bG9Y(LJ#-%xciok$|BxgL7N~^PxxTlta(*UW^(KNSAoXLj(5!5Ka&* z-NYtNynlu@V<^*egBUb4Jp35as3cG~ot*V0We9AENkGWz`Qb<;hCZjMfq z8}UoetwW&QX2E4$@|=j`NoMj`!wXC@?~0s;bX75ReanVE3~1akKDS0LWQ^fLJvecSJ#F-3aiHd`=_ zMQ{oPlBM;_yOop-45S#C2bKRR0rm-aaL|Js=0%Z+2XM<;ODIz$bTQW#yGkNVuy?x@rz60+s3Z zt|b}%4msIaP*`9dbotrjNdZZFVmj$DBll9Y`>Y?~Mg= zmg+dm@&JgZj2A8tfI^8on9pNqV`HP(MF_x{K97BJTnY?WE=S{E>E1EqomBLg}>}Pldb4DSWz}rM$8Vmg; zM!qo!128-7F$jF`!L>O-j|TKK7Fv$KCvL$O#{mxm`|17+P@xO|oiaoc@OFwIS?h!8 zFvQ6s%}!7a+s(8eL1q`AyFzny-kbM!4%_xtcvf_b*M3I=R6mBc#2tCo|A(;mj_bK^ z%T_dQ)Cj5zBgtU_dQ~L@Lc$Zu=hbz`Hp=I zIIIPSTze@im+@*X$v0BCMJ;BaNYT7n|9YN)o9!x741&+iF zc2a-=Q(~UnL`CNwY6_2KUT1OI4VdlZPOgj6ReTZH%3_9JpoaJ~Yt1lsw!dBQ#hQ z4gT0SvRn9pnzC+n5#-H(7>@(bfqgn$fe6EA(d6ZUMy%=o}-oU#le zTD=mBNR6Zmfje*8Fk;jUdimccav{Q)aVuw!26XsEUslnBHSIPC-%hfB5N1cvE2M;h#0Y+wmOkM|G~X7e2YGnb5_FHxcGFZ`{AnfH`GM{Sxd8nbP(nkj?j-g@xx) z-9Y+{;chg8-_lio|L1__kA66n7CJkj%ei|kRn?rgS8Uqz)ho-sTzmHSZT0R4EbQ`% z4)3;XY1PGYfQ$Cl`}CDnRjX@Qgi{b-9|w&i?n6JcvQ1n&k3YX6G2zO@$S6s>z5R}g zuX*>ev`U`fMFwF%T(q$Qy9Wmc(~GCv+}sG))gJsbvQu4#+Y}WPC;(kwCu4w^%D!SG zDt19tAosZQ*r$qVwAkB^eNvg8o{q+*as`V^19zEUeU93;SnAn+pxCDV_(Rrq%T zPQopnC>gJmdLqt`NZ@fgRDWG-u2#1vY762>Te@@wx0f=XA}Fd0M;<&F>Hf@i{u#xI zXohuxTiEc=6ov_S(Ur~M^V}@&(Br*n#zllYd-{DQEx^!M+kcUpi-dp>^3%yR{u{Uv zA3z?=|Ki(Dh4t(A-LsVbc@Lby?8Q_iIO2`W2v}XYIIdxh&IGL+WsoL`@ia7)Zx9k< z#G6yVCWsJccie7Xs=>GP5$z`bw@0@##URKX0LP6qmW1pH$FW308pMvDfIi6eoRJ{^ zNv2`(pCnEt%sh0f-Q=n#N~G(vACN-p3;8M8C~|!nqWCzx#_BXe|}m^Z13?>@&E0l-rHWWc{nn5&$4KTaY}OcMa8RB{tlun zkS6j9fcht{9L`lghTgA2na(ZmdfaR4QSjJ-0RZ2Ma6T*1GB%1gM}Wu9IXcPBxoCC< zsCPx8_EC$X3>R}xPtT0;&u+V?pYAzXo2@bC>i?V_%x=D+uYy|okgBgx*)v*wF`9`a)#a7Za2spmv+B-MdHMWoz_8_>6%7iGeC8lTB{-0rqY*x4$`_Jbq zH5LfCy-_$x1|&kYMdr5I7;AF z{d|aJT=TrYFjnuk;PJ`HeaI41jq;Wh8Jj%19-HA3n{n6t>0^Ppz%9*vis&OVf5od; zw*YJ!1`gGcr-);lH6u^B)q<*yn}ZgxyXA^3M$U7XRi{wkzK~%K?4UnvLQDa7P{|ZR zD#H*xWN=p}xUFC`WAjqAyp{JompeKw_MQ+IFSq)U)6h|Izq>EcPOBpuKx_F@ML2VS z7oVRDPe|xWuz)~BfA-cJ;F~%jUCUrAv6$0O^^STGXhc|TH_D2U zhxFzYl6GgGSTAK674L^gXc2x(v7eiEcw$_{h0|y6p+hS`NTZ^V_zf*1ud}QUBUg`A zpXC7@ggE#TLMKQzAQFyFO{wBa6Nh9#;Cp0bhX8B}z1q*5ll%ylm){`32JjS|*1 zVmt};$r@!7HvgDQNr>+F<>WeOv-rawIV$jpxjw16boa@&9&x>@%y~;W<>AUNR8>SA zLG|u0xOg z;5!j;Ctj;aZVV4J2;zFG!d%6@Zt==q`|+Rz_ai`fT7gd@rUbdU(h)LlZ_Lh<*b3o_ z)R_-&-)lYeKY^XYVqgExBK+TlNlcDfTlGB=*@Ctxf)toSA-OB^v84cm61pvlN2!ql z>~Xn<|IT^a#xGQ=6tX73H@XHF9K760-0WmWgN1kmxD5Y$OY)S&IfH-m{AQQspMQ)= zR|wM!bJWCrP5S&#DJTF4z+OH2ZR zOnr`fy^gdVXIQ^TMIE{a#t<>H2gLvpayg3ra&mI1Q5^4M1%zY23WeO~S@@&XzqIvB zHDX`$?%2VBbL#TZluDe^6bcf8>e^Z*oZcH(KF+!Q$lpRk^J&**YhAg8!4Hjx_&Ny< z2AC_t(m}OQq>?zNAPgne@IvwXpH&wfWInH4_T%1S=gwgt}NX7{LH0|CpjhZk^Y4`B|Nst#H zK@uM|NOla!>&YRxX1mSpik}$OMQMpnKgTu_t$FN03+r7QL2|%4^KEN@8p-HM@UCHm zWN@N$(T*2uRxEKhKx9B;cVW#E~9X;6uc?kB@jaW##>9oQ0rB>h8Jah-u2`Y`{vJfQHzc&4HlXR+ykWE zwp*8lrVZG!X~eHWCmgjM9qpR-9zK8kAKrOhR3Z9oce&rU;P(jNxOUnO4J95?k-ShX zC6r}eZ~sGMT7hp{wL1Cvy}=}QiH^Q#j#jOChjrAERVx3B(5$!Y_P-@;^WGM?jz?{>J1H9=xkJI@10-Na46|>;vyU z?Rj2Mn+_DAFJqwCdSO<{P&iib`B1-R*dgj`yq6TEc>fpa_dtqEPAm3@h>jF_nf_B7 z_V454RFt%$NiLG?;5H)$T?A!7Uaa9ZxmKl4{4*7x5ZkikIeg4X@!3HgHP$Y%hvaGt@uGGYVoFfh zw##;ASurCkf;s&%3ZThQgID8*>rL8vQ-p&v*aj`AQU8uXX_@jdB%z{B&d7;~XT#{} zYRh^L=W1ChIFm1uWhtb83TV6d0g#eNx{ovSU|p}{E^!mlT_>~Z|OQtIqqE>Ppt)iEJ|Bkzi-d%{fgibaLTNj0QlkO{#f8UA%@`4PK zDIqGkp%zD$C~j7(*G*#x)u8-4bPLxc@yT*is~C{0WJZ6H$dZSj-kNn#JcgXalG;L5WYE|>z}h3l0L z*AojT=QR=%fA6i~Fo&C$7w)P&YhR07^{YzT!~LGxXfdmLS^hOtXbWockbW5^xfJ)@ zbDR>Sn-eEJoS%pmPujPLD;c>HiBka%K0ZX8)e9L64Fz88WhnR&r@=bq8N9oAj2nW9k*<}L&M%Pb=m z=?d)HtDr&|N$trT99|y(s=xGruD;p7l(3>(mU|==EZ7v?Q!ar}f*c(k^MWOTlq3#6 zpf?jXD9RUT5ria!#%R$@8D8G#>P%7tMxdNvkQM*M_9F&9VvtizW9x{+)+%5OdRzrC zi(a_-$ zoI2M{F3iwErUeP(NXW$2ty{@!kyaCYuSk|rQBnkPu)p!abxz3>Yy}N?W`sX!Ru%X7 zagWbxWAfbrmIm)>dJ570N%TjcmH_-*O0zr8at2F_IHrQUOHKkxDa6oMNRNb%X43;Q zs)WTA52w5I>o1&4|8Gd z6yOvipd=V-hL$tPm4#W4HI?ym(U(b&8VE{q5SE&3(rVy(w1nb5HGItC7~zj&jgp%7 zVc7}ojGp!GpEtOAc$j0I277D>4RQ00tJ={xPQ$3`t=%=Q384#gD7d)-wrx{whP7o zz7oyjuTzmvm_z$#cakn3n~4~DA>SZCcTx`r3Yb*Wi8lu#ZbI&Enub>b#S2peA?Ig^ zEVb3P&|xikS(knCWOa2Nh2mOQ zxNf&y(%j$5CAWJpD;4bkM~LJo1f?!FmC1x=o2xPONanwui5%Oh`?Pr1P42oUcCk6U z_x^>-MY(-jsuBAc${^3IyuDS6gy<;!l7@q9dTGrs%kpKK%*Sm1(+Pdn#TS14iS_%V ze;-MbHydKXkcAc(7t-Pa;iw?0rPra;S<<30p*1h*nG>bL$uN#lW=#}1$ zrgq*Y~^pkGb3K{MbHbiA6 z&H@!X@cP?>#8K;p)?W@$*jD@kuoFj;n>l)tH;oq<^r=~(4>;=Q);XhXfAeTQN9in{ z;Ha3gAOA;bhDCmctG9^+lHMi*`vG+aC%DCw_e>#BF97F?UcJ5ggZ!R)Hv#8kmFr}^v@5oaJ0T2yf0OTb|G=g zS~h{uTE5wfh&+gnOAMAtR)Mk-G|+9hYH{RdRZjy&)(tW39?YG8&y63(Q9$AkFpDc8 z0M|_%O{N(KG5Y30=y^4XAMg0uj@4>mEA#S=76W}w+?MChpJ!dY+R8=U$jB&!sqD>B zxFM2)Q$r}nvlppjbS2+m}+Roo9nm{Nj;81L2&P#O5uf`T9QLGdmH%j zOJTz;xiR^uRi|974K?JIZRN!VmbxsyVhZFRqZQ z3HOhnUi8#Pz>@X>)#9CL91+z$jeZw%{H$pVL}b9DlIY;XcL*0HRHdtEZIQC;;Kngs zOp=otea(sC_Ef~q|KbK&au0J!a#Jmq-h-MXxv^GAq*os$ySBT#n=Al=l*BXT1>78= zL?wbx78aH+l04>b`4`k*P1+E^R7EDkq6VxNBJS!)Y1wjC3W6_Uk&XHa9~=O1N{)_? z3%PXeOL{@Jht)zlq4D(W)MhP;*S_U(s>lrm{8?zqR94NaEHVDt8>m#z78e&YN##>( zm1bD`#sY=GF1dJTN{SJs@J&INWHddi+fNy-YiKgUucQuy7NKo71U_Kb@lRbsOy68T zA8cQ(4yiwoWb-!uW)fu<9kuwCcZ*?}^k(#Ap-GRpH{)$T(mWX=7WU-%Es+|7A4e75 zu9hkG*m>$seDov1y2Savq@hx9;n!}`-2Y>^eJvtoRQ$;iDbAvE2_8mVn0?L&D@cC>nfl zxRolnQB$FR+2`++gDVPq7i)T|5c`TzTp}?YaaF)yQsN8gLK|NtcAEnlUP$zZ_)F#V z=_h1Wshr}m|KTm*@AJP5B*woPqAW?4;YjmTzu05S_Q&Vf)G%wRG`}v3zK!cEzrjn_ z_TCrs?-!QEP+01=^Gooc(pwLvPs9>+?Tugc^+#zoV(5$73M!VnVj-owpnwlKGzTLXI{>l9#uB+#g zYmP5=_mL7QPgpQ_ialN{Qx@>{Hd#O0OBW<33_~fRfS^s`@+y9zH*siElZz zQrg&%0o>A-GKOlZt0+>tp^xcpptfHLN0Q7!>)C}1A+s8PThFu> z`k6#v!HWp{aZ~6?539MzxTY^^I_+^%5gp|js6`WR^{tm3kR!m(X&-Ld`3n!XBuN>F z3jOzUaR$G?f4WI)VJ(Ctb28BpD+EYyL{dcoSu5rn*yY5WZ*!DSEsvvEbUgKh^?l}Y*tmJ*zMCvX9&V%ok6RR4ul{&!o z5xaK&+D7jIkr>eo7aVoxf6joS7TeZrLmcts=u-wFnZryl*>OY?L~b(fPkLxD{Y0cy z=cY;k%?!L~vb_{8*`(tx?30 zBg(K>SSIJQab-zpt%t4jcRL#^%biQuS?qs|kMG3?6uGk;5Nzm@XD1)l=w!S3Tpy|V z=U-}6VY+R8pCh&df4b!O{`_cK26zo`Z3qDO*x4?R54bb-?1 zf@>#xqfD!`=QOHkJ=ZsVg_|L~?P(~Gn)jT3pyOM(t^V_|*5|BwyBF8(>O+n8Qt&)m zb4SdC^qQc1aTV&=)=lxtD}&bs0GO`(Rw{u{|A)Oc-=2?8{9+>(Zt?^(Y_wEJ%EU_p zq3bugwOZ)EgbxRTU9EmT3b(8~8`4H0iQniQ!KdHDvTD_3bb~Z*O?#-R#YV?1Tk-s= z`Tykj7|8K(R2G!`^+Rh76J0Dc06TykY zSV~~W4TCc^xQzuH0cX$mW<#ptDu}3TMDMyCB;4Nxt zHtUrw%p51_Lp7pM(0T)PxaI3Tn4)QwzP$rEHqc1M@DT_xvmw1PfI^NN2>q|JW#Jlv5K}5S} z!eW2>K=0^n_k3$j5(+8zrXrH^g>Dn-I)}X?BC$4zSVChT?y)s%_Ay{q4HI1t7zfxQ zyB3y461u3VGXBeGJ$Gi+&9k>O4j3>3Ar%ca)P9<4$&XcaBG4v?8Gtf`d1#F5M=mUG zQO!a+W=Y}iwg=?>r?j#}hb4i;f+;HB<0DA)eIWE^dcWiPd;TTK;FZRMU+%>~LQ^J2 z0;Hqdxo7Q0sxYmx=G-GkhoA4YaE-1EVuX12T|MFXJt+RSLvP7o`$@j6n_ z^?ya#-knvR&+lx-u6EE}f*EmQjl64j8XcGKa&^gZFUAIswQa~ zIUGvP7Ce2p&DKR-N!NSt_5X(bO0Lz6)k|`ZGfFO&P5F|1d!ZTxeRx~WYHEOJRIsG= zk$ymFU`765#?^zF#lZ=gjupF#rt8ff^**KuRG&M&s*Mhup4f*M_wJ|NN-^!L2%W>Y zj)S`|f8K^Hly6hO7x_QlIYZXXqz~wWFb@{3qFqJaxJA>`gB7VG#~!=THuxsY3I2?b zyX{FIp31@91V{kwl$ZioD<^-MT)WuaZa%Vv(jos_?kfl^v0nS;hkOwJMxoP|oLdlp z5zC@5VPX-;OOJ+4O%-*&hz2(M9~O(@y3>|W>Xo}pUo9bZ0|rT+`OKs*^hTtq%Xvpz z#|rf;)rz+LZP^7KAp_Y*lPiI;Rc|!0oyk-|lKb=Bmw*EXchmbnaREeDFD*^o>HLdB zF-mbo_|m+kkMC-?qvH#*xgYVkhdY+Cu`!RbhVXaGZEGTMv(m>{bJ%Mi{qfINR%5&6 zP*1{B;l%r~^R@WUAYerXvrPD9<=cmIu>jen_+!4;y|dt#ly4edqZ-u|&?vlvc3Xju zOL_Uz_0M(>=fyWZG*1bpc-<~!D|GvO(E10LU&H0L%Q$7=C9xY~8M&Y#%r2BkbxRVs zUY(*Ev!zRJV>IAeG|f$#>~*P2vT^L%G@sXm9>%C>){4gG7^AERUan2UvBx(-qWX1z zS2J}QpPt>Hkyz)lZ`_WBdu#mCf$9L7LYsU?4H=j>-1TjM_sy1-6f#jMw|js%_ zQMD}gkSl0YfpCS#h3Tg)=rq|~-CRFnyx4^h;<3&0?{*G7jLDGSc8z~F44>eJrE2vn z;q^3)U@*&nzIhQH^V)unQ%rs$CA)q7DT zZ@P{-eF7D ze^G6WP3^tr1P#MU`?vE)nA&j3k3XYBsaC8HIGHIx{=9hZ$MLD|rHPsDM+57HO>oeD zEoQ;4o-8>0`xni{O~3n+POf?%_xjXm|HGXQG857i-d4#qoi4K-pCWA6z2#ay$3>yM z+5W}V=}$68GxdQ}HSag&MtZEJyxRUmh-XEPvD$=jh+N1$k4~fg!@j}AGQT#h*t}kN zT&B9HI^w(+>Ci+s(Mvfd9gZmr>HB{)Q4GyB{p;ka+^!$}OfI;r3X{wO6q- z*S&Zmz3VqE{R9Io1+0X%=pldBy7|LSE-WlMO1-??dSgwobIV3FGgfWM+-vo!J0;dG z%-fmDpW^wm^f5sdI(dH8k7(MJ&#m4r$vz<2*pu{CutWL>o!sybahXRval89Ndikei zyIm+=589Z;ay&V`d2hSU20z{XX{Nu?{Ph==ZKYm$egmHdQ%%BaQ&ztJgoK6|fub`A z6>338$7#jb4_Vdr!bpBarhL^u{g|G2i9xFCo$)A1l-c;i>WWq}gzC@l&qpW$?^REP zdj~}-?LPkBz@~W9SP_1%z|!1(JroMKGI^>4pX|RMw>lh_p!_3-XJp;QMs1!})#1xz zJQlDP-6{{j3(0pf&Ol9xP~UTHD!2%96ar)7XR~+d+nZhbGIK!0gs^<#qGnf3L@AW) z{zah(MlwjQc&YGgN5Q?6>UMHPqSOj;(iFYQ$fQ*eE{zel%$cvRW9O@^-BCA_&`-VC zc$`$(>tm^1)F8TjJ_eoa#YF$)>zpyr3)SoPyQkQcxgYy>r~Py*THJTR{=jIzFo)p zq4g5|Bq!aUXeQcY@FTfL?be$Eq4(VW(+PXLwLIh(MkC3uBypddFAwhu+~*mvPQDn& z4_A6a3q0b}ygwFF8BD7{T)Y>pHyTukdK#agX~;CyJN$ zjZdMvX#ie&X@1kg(B?@wTAN5NutfKPx0?3t;&Kz2A1WDT>ymheWpn;FHumM_p8#C- zn6oU6PY$P0FheVxdy(pXv$bx_w^+~EZ6!UPh1nyUa9)#{>4&vPm=brs72zLZMYgc`9#%iQb^8E+>34~0WVAma_>uNFpVxq;iuhH! z`aQduGg{0K)}&>|b!zRU8H@ZUBu72HTIuI0j}?MjA3X6iT&KG;+3wFQ2at&D6yCelW43;89()AyZ+`yxti=lD ztJ8CD2NJ#*G@ZEqhWpG44sJ2wxZNpzQZ#OKp}$nx*n2szOzkfAt8uvXko_rsYRSpS z{JpI9pu z6!+4QKF@0t`N=V6P^Py1_CqoDWL>40IBL+?wQ*;Nr^h~bZYcEO)sYZ_+<=+5D z3L0z#8*F7>zSbFGs%p^NVMw#K;Y?m8ZFt)0)GkgP<4VToR$x@Y4?U_C4(UW~!!NI}%&(M>yK-u?<)hZ?TrN5*H0dsEcI)&5mzM;(Hn-*vznAw}N88!F z__LYW`7Oi2p!6|QST=mZ(b)YXvPrk_3MDrEB4a^xa>$0OA3$+lExm#3Qo?cAliqoMc((@`Jdryg4iIi2-w!QP$nX zp^X@udT6{&!d?39lH7WKay|_IoERGvk2e0IMy;z?ag1dc6PdUQQo9UvOzLX|!#=&N zdscQve4Bf8!?*Tt=SBb997dr63b?H{&j&~e2g!lWwta2SK@EFWFmgU;>NxOceNnMS zd{uj0W3 z`t@tY3Gaqdsy7?}n9GB5bLM49GzEh@k2!VB4hiq+%9Brngx8eG$m36dr$VqQFdu`wDUM>og}9rZO2HEgnu!n458!b zZ2g7DiwSrF7fIvU-Ph63vuxngGyJtA_^za=`E2vIt3@0EXI);tU+1ycDc1Xh&de`| z3Fd;q783)LZ=%|F^bEw#=~jOCeyMwNs+O0_e(!ne)mfsS%=%S~YV&hbKig5L04Pyb z)V_kSY-o7u@~!n< zOCilAo=fCx1gs~1tm&xw-gGPbyz3rmpDye>93S~2@UrA0DmaLmt9meJ10KP(>GIWk(z95#)sU7PaII=||`A>lHu zhA8ix_a426&e5&gVe;{G7EQ$Kupg55aGFOYOwr&{>mPNKo>j_~>_0W4G!VE$GWuO& zNMeohoy~QhI#-?CBJ6%TEIHV~V6r3GMD&k0*yIiXW{FxJ6V&2WkJIUnnIznG;jtaK zgkA%DS+I$|&f}1-#Sw$pmM){pUBE289nQU7cw1t9?43;Au$-rrhQmSEcb~tq)g!(O zksgTe>f;4rE~~1vuJ_~!l{`Gw;JjaZ6Q5$J*G8R{QgjEeT_?1mz|LJfXD)w?K|q|% zU_*-*z?0_1OvZApBtLQgeqjzb{5y-Dr_DlLXP8Ma$_mt zbKfX&cO}rrwnXOdvYo>$L;h!E1{<-&#DpdZeaoci^=a3Jtg!y#Zmwte)t!sQ<0{Gj;LL zz)&LF0qF9WW*W4sBXcTf54JdQf)8bA(w~4S$bdkM+R8)To-YEI6hiKL1+u~r?2p`w zcp#{ciMRh6F-MrMBK$UpV;zg zG^KpHzjtSK%RrBBzIA+z_6@!d@?3W%L_s6GyQSm>3^Bwr`jnaMq_`OpV?4#O{a703+_d5y)nBjNav)${Wm|UMQMN{uO26 zEb6#=0DQh~=ixJ020oQkK>-oBQeOs3+}WHR>T^O)Zj%`8+SJqOGZz5I^*hV?OftmM zn`yU()!Cww=j*r2p8BLaEP~|_-t)x3-i!+9p-Y6!-&M!!}Okg-A~6n?6^qIb}dPHW6~F!wVZ7wt+(uD4K0JLE|bRl`|A#8 zc^rB-m3$q1FyOaMrY);g5eaY6I{cw+()EJ__u)OZPLbE!mV#nb2KR(xJ6+yKdiZXp z9!z^u^;2k&b(lTan6>DYP|RvDu+=8^96t4m@UIc*SI~8>J67Ag_bW*|-FXDTQ*5H1)k&R!o3qfg0Gw#< z&nJBhsH|Ckd?w<{#Q6!^U=S^rODR8w&-_|lhR)CbH%fq`&!DHLhl+woM|Wx!%GS_al z=d50)a26+^8@qj6lp3p}KI=YSJ109Az}=EieDVB7%Pku^Z(RQ|ch1v(ntIN4QkB>U zFg@AHqWRy9t%B&FN8f4WmEZpRm@zl^ePw9AB*PIJ1d>z~^3@bN{Sfg4&66p0&zj3I zOUq2UP5Kq|9aeLPYHs49-QoQ*`tqT-igvHUe;3_wE92wu={Q4h!iXEYACF$f0Iuf4 z(OH`jzAsUR9-|=1kX!!RS7ltkHxks8Zwf9uIdgK;h0a<&K%U6Tp+w&hc=cqTB-vlO ziYc$4B|>xw(Gnf{!02t@krrrRP3#NF%tLx=mN1#=859Pwovjsk-AetmpH+$?`Sn`_bJ+>DF?kS*UKvoxx{V=GPd#=&`tfLc1#L_qAP85#2PxgQ zW8*u|$NFt;kg^UR#{+}YXT!mGf+fMS6h+2p-txLNrZZn+z7>1pS{s|1DA}MS+#Zmz zSL`e9Mq%d&HNE`OyzN|U+A2dxw_MPcrSG&AcpoIpooy%-M?vL@Y(#jwkzp{soY~@b93;t2z0R z8V*I4k$rc5-Z>C#$cpPl_z?P{2#1E0J*P*zMZ>Fhc6ySTrh4&>G?vHhzzgW70+zeO*Og}5y1^L|D}|iWLDSR0u<(&&F8z`vL2&vNB1uJ0sm4u z?)Utr7pd{|?+J}u(bry;(|lsK&TsF4nohvqY0dgE(vN_Wi)1uUkLYb-y z)1nm@c}wm}{;quU<=hUT5*CBFZ9=Qvcb2mX94pNbR@|$>MnKs4(@r}s`x>U6R#ZB% zq_D38IOXtS?V~!^L?_i~mboiWDv+(Z@-?j-!}@@U@jZ>Icdv@-QeQlocXIk?XBu(* z`mDUpHp$NHU5A0bvpc7;_3V;-m2>5Hu86R&G%a$O&7>0w`k1==DML+O4SmVVV328E z6*;>MRZ7Gp(|~4=lK~JA~d|W zvzHt9yfzstZcR`+TTfNAUTnUPT{n8eooKBPN^rQ@q?dki={H473`U8z3KIqX@ zO7b$|^TYp?Vidp*{gm~@WKGBUQ#%2aI~D^Wf@A|trs;aTX?Nl$6(S| zn-I$-G*Y8>UO0s)6k^oS@9@oXD;^!h3Lp3&2xj-%Iw(E>7GL;VKHGK~8-0;g?s7ea z?UaCb(G`h%9^<1)6Tg&;U;cy&@LGNwP+zhPS2pqzVTp=i_U?6;XomxQz3HV>WU@Z9Rz%|BJPe%VE~nO|p8X*K} zVfaq}+}}^HRo!}PP2PlN-CcX$846*Z6&uc~?-XkNSwu|%^|Z&nwp*g{<7-Q}Rd;lp zaw&?H6be%NF4v;{PT9rM1qLoFC?tOJtaPiFmZCP``ZyqUgave<;EMw{_Kms@a_zHh%$5(quRnNC-=l}DCMVc&Tz($ zJ@fXZd+qh3@Z83Mjr_Z|a9%W-{W$EgX_W=px!Yw<@xGE|A^(RzuX!!os7&qD$BvgJ zL8w-Z)s#G7TV{4{`Jg`>M;>0fNKFx#RI@M9FOZY@JeDFxDH~<-6DzvT;ONnhxK$QLu(47ZH)t)qQWp5HqGvtPS2wx#TvDMK! z`Pg2l#rfI(GY>xKUJg>2@iwgpT^;exj0*(4AoYWrQ!S2_g$@bjMmTZmWu|p?W?wuI z*&csGpi5}V=fQJk95^`JDz?@*H!NYyGB%`n>V-|EP&VQfI2R~kMRo~6@n!SGykl!| zNew={4kDewF}<7PNM$b(?Cbf4NeY7+-vhNhK2WqjG+Gz zy}x_L*03pSGm2Y`fSIys(n7{ZN#!|Sxo_a^$Mzj>c4oGIb*qRpWJrAe<6@v%^7GWS z1?jRU_XNlDJS{Lei`4w;_pSBb2kQ=GtsQdQ7iJng&6_VNfju)TIkaIyp)UGP!iZky zgw)**-*7<1{X+iUOt^~BDqoFfvN(kZq-rq5WT(#7QJg3jMpjRqomTI$I$g{x z=eWM$qa4q;>oOCS=AQ3f<`*u={r)*9h9SA!-f7Hatb#;v9 zm4Rhw*(ELo=p4BB-g@v*59-zdXzJ{=E2_y69~T#>{(13gpGN9=HHJUwjki{Ju(Ppp zUy4?!Y79-H3ss|4WKlc!a#bkpLXQfQOjm_g?RVv~2)@b}|0Fs%4`K zg^fmRMd$9o@|+dsR-XWkli^}jW5Yi{J(wOe;>-Cy%xSq!I^A66Ey_50Zv63BMTS`Vth?f~J8tQ( z+wKS%-nbV^pQvo(o_5anrj>$SVIk{UF+;k~H1zbN^f>Bs+{7L0LPd_Kg<7xwz=JCE z?Q19Irn(v~@1TzCR}6$i`5cR5U%8hrhq;>j!_4Cfd%;mZsBXKPlh7ACS6&c$cI8n9 zck*_}Yi+}{Rn`xzWtw_9t7$kbR&qY_e*0qtfZrSlZ{AF|g%5JT z63PdW%d&B+S zyVXtg_iMA#9<8*rnG!dOQ*=>Ot!k%+|u;eMrN_F z)w_MqDF14q|$C1Epyg4_QH$O{42roDO9e#HA zUowLNpFJn_s6U<$il|BCuUgw@rn!33l{>fs@DrBEEK&a>7^imn@CVzeHm+tR zp5>>?$oow``=*z({%6ZKJImy+WnUif={?au?gZdVIi@TK)h=R>%L>GFPQWltqwR+D zouEq!fQogNUmvgG=4OhHrnwn14}ItG-n_wYTx&Wkb8?9+wZra-?Dpv61vi|hA2u?p)b>%8#@B^! zGoNA#|G4|R=#R?nTs?JBCZ9fki@o=Kiml+i*yE-{dxmlq6i0J%qt{WMl;rXZilvG* ziO=<3m`-R?B{*93kDbZBR?7+7H87Yn=;23SzGLPdzmbL@=1?yy_|Wl`wforU2u`OV zV|(rxKidCjPXvViVbg<=lUf8Vwd=Dag}$g?^3P)#BO$W zUDti#itqN#Ga6B)CT=fN{O8)c6MFK$deBggcO1tLfN@cbTTs7lOIV4?T2VU%8)HGG z^GcCCUSf7E?d)pfvUmQopx5x_RW=wc5P(9^ufA9BmrR?7ZQJV(#auhrxX(|&okrRn zymsfyuj_xr=-=QFqF;i^MA9P(uL}-h@VMQ^6bk^1CF#9N7 zLj4EqynvRA6JK=*zK;w-N4YFN#W>P&VCP_Sck-xH#p(G<;MD{^1Q~+MjsPWSyS_Z4&fsxRs!}zq&=J@$JcLW(0XA9`4mT$|iv3@iV*K??2-J`y^Fdmco zT>8tOBQj%)8BF|W*EVc@#GpK@!Y|wGT$%y*GSA;#1G+7FUzMh#6BLPG(vXYZyXYeh zt23)!1SOd zKTjr99nr&28YYBvpOm(kTnrkB@>rpuc*ChO(3@ESLm%fxlan?EQmKY*k$C!^mYJ3l z7R-!WRQ^VNhb%!LWjnK!yQr^YO}m8@w9PPYE5vYoIIRpPhktE|=(yI_?R zOk?UtWY!C`KVxOyGPXdLKSkblM*{vuD+dFv~WhP?6rG zfJc?)dgnG9=+p;G@DvDIw(sIA{gh;CrPJ#leME=pO2PNs)ca?3=G+qemx~J!788Q16aK5e3Z53&12i}$E;+z8+PQc*C zb4(9xv+q(Mtj--Hslv<5nN-nic1e8dEb38O^c)8MNYWo?M zbKOyOFbErQxzKsyRl3ZGj$5!{+q0|>FC9P7%AzulTIp$Hwy0)X-MBe9!0Ev%PJXL$ zT(9B1Yny+)(pf#7vBGmSzhw`-lm0YtX@KLsW>lk_Li6ksy)!oM!}y}W1hLCEg z-LO^r0tR_VBlG+uJL)1C`Ds|=M`^epzyX~!2oeC8jg26ni#9E(LFD;d{4=|rYp3W3 zeX0L%CY?~Lej)*486&i(GTrreDk8?_aQ`v|3a`@1^kb<_#_}(|pWs5=O!s!JsieN{ z`&>bxESajOIVo9^iftmX<)KEA>vTd4tcQ~XqwXrdeEwi5Usd7`s;UKMT}3cDN!O}{ zzukqI9zsBTdzJyyujn_``}z&G&VWJO_r95QfkA%tp}z;*)rdnZ6rY{Z8XMG)%kDpN z^TJgf5o>*(REG3B1G)YGbcLzqY3b4H#TZthTUN)od2`qP&?~9DT>`bCM@}t{lJ@TC zHTK%)v%k?47AOsZS|8qq8nK&-G&%%*e0JlR=^zuC0+GPb8uy^BZ>(|rwtsWlA#0Dg zQrRfH>d!rRhmVYZu(Efh`1DH0&|9B{c|u}a_Pxah1zQCeR-|pNdOFTaeq0Pg6>wF- z#?)8oFs-5qY|lbV9Ql##bIF#af3fcE>5l|=1;6)G3sPQkmG=7lTh?}iXk=k!#V;nt zygPD5j6lA-n(f}S7HSG`e#c~1B?W>bi)lf7P}nr>lAFFzDCpjIuKI7|%4!RV>7PIO z1O(_;E*p{GwDeZza%r2lwCwgYWlzWFXW?QP%JIft0%cs z(CypVQS4ZGjmDdITi20Gfg?{#8AmOA#;h`hG3Rov^2>xfaEw5-@bmMNneWljQGq7$ zYfm)WL=qWTU-eNr#vOQ!?qK|4_XDwtd)zDbi0jzR`L8K-udXf!7HDBn5vH}6FsMz( zEp$V?JmXQ>21l{^A`6>@!58@$tj2(k?@YJx%FC}OLxQ%pqTuFg;AoI&xLYk~?#qAV zRn_9KWloupdav|}O*&_No_{|*Kewd4tGBNYEb=RUetQ$OC0h6Qo<8)+z_bHb~0WQ_W)rLW9V3(Irl2W!poWZ{dec+Hb15 z<*{1sU2iLo>ij2cUIAK~G`e~PHSFB$|BiC7hO9OH`=~=cnqf&*J@GT&!4zM$I8W6} zH~g)8|NLLKy-dC!(zU69;kbW){RkBHW{JP8^1X+E8>Z()+Y*>kVG z?EdlUrIcy6^os-vr~@jqv!du-L(iGnSu*k93KA7sD``tw&XM#9g_OF=paE)3(d)9*jsZaJj0t}Ax^U!8R>NIt;6X~gSw=jf918lA{U?*>CtYl zYa##e;lp?`ZK#YXb^L)6K^mv2uQVJ0HE8ym^MbY9V zC{kCKs2TYFp>n7SY5oSB;v8y1W-~yQz}JPDOlgr(PJrfzhOH2n>gS)^N+wlF9QyDG zh4O!cDU+LwRpYEe{_z}!uWZt(aX1G$&UU81O+~$;sQW;ts(jIUZOxBaUiHApt7t1i z7n+gmc1A1@%8`E4d*ZG75g6_!P57?;VOHv8`<~{XUfsE!Yx#-#dzY|KaDtH`G%_I) z(u41E9srNF9DuDZW2u>==Z{6@wb^Ohc7#Mmx16i={2c%YcDXH&{e{O@p*+KM_4fKbGg30o^LkS`5PC(DdFW^-P|6c zYUT>NE|9wQ-sA*V_Y#c#%F{Gnn5>8e{F{&;06_C^Z$x(jBL(ySQCnN-DTpWcqG^cW zK8*^T8yr@p*|2r^{5{k$-=K@mt6=jx-S@VDlQ7>0zta@K#o3YY>A|li@s=oRa_60= zDM%ZsEMP8Xymn-4DzC4%_ukL#NzFFga-M6*GCV)k?uT(!cFqUk;f9Y&V&RjOlj9PS zQSlyQvl#awLssJEm8N(vudJo61_r*sgmf@4zIR+rNRKkmQwPL~u8;=+0}*15LLj$T zR7(UtAI`y%LL4vF>~R*(J}N>MIOtm4_|$~beexn$>Qz8!j<%g`+^~5w9jUXsecuSC zj-BAv#MjoZa_me9{Nu3vFDWuYh3}dG-Y{AveeN@h4fe^`DIhIl5u=fV0_@}S+~6R~ z|DNUtV5RuaZ;z&2L_M*X$^*>h-nQdQFwsYdwPbun+TjxofH5@P!=2&EeU?nYbruyD zWjgZCZozdj)^(Yg!=MG6M>10Q|Iqcb}=^t+~b=^D|Hh3D9{4 z=^oC2|KVh)|8CXH?DxBmn8j-JpU2XhbMs(^s2cS zsw;#E)C;_UyzXg=Vcc=olf;YE>InIaD*-nE%L8$h5i=elsQ?zi^9Goqz}-BUJ_h^z zFPP4E+FsNGzZn?!2cafoM5Zwi2XHXZO$~*g^CqquTEBoVoq>V5W*}#P1`w5y@C3M9 zwJ0wS;92%e_aTD_y?kiE7Rbp7aqx1@Rl;9y!lJZndW4tZYX*lSyyqkcCMGMFFb+fo zG(HAY5D9Oe6SGKX4mFs)@M-|T{deYfF{@XnF|~k&Lug+JhQG`NS21)z;FHk!BSbgMbWzWR z!eBWB2aH)5!mHo$z#g1>6chxML-7<4Ix$M|!tJ%Va~gU8VM1IP zpqEAP09uzB!QmLz9|=ry&KJahgqQN#ZiG9a)7pdtLMy?U1Fs1XVm<+M$x7&&Gwd|z zIs+h8?qX>eDMq7U4aY|{YU^RToXr4w;f`}%o1jIpOB56oy=jUkU_#Rz>e}B8$NMsm z)2y8jOzXU`Feu0lp8;w^BObK9z@RR-!TXdiN4YKi4ZBBc{7}J{(A<6Id3JgN9BqBD zT?1T}s^(cgds`B`L3WyRYjfT=F6o}~Wz?+#s5_K0H@^H)T>JRjoih6Ee|*&P|HD~i zuD;%Mk5kX$(sGS224pcq=wFs=Zr;u1n<~0g(}{)L5aOvmULggJPkQm(rLq!OMc_g^^mF5GGB{YG)v6dwhTG?gWF2@S0~{DV#Hf zBkvQy9D;#2ip+fAi-wr$M%W0$V1J@q8n-@(O4DQq7$EZu99)VJ1=d^bYKNwK!7-@_ z%!cc8J;^f9U`j%#r{iI8Xy}>a3kKxW)CkrdW)%<+fbjZ3}=fdr4uu`xnZ-1VocY45PDa$Kdx@gDG+wu9c)0a)@Ki^K* zg6_RI-5`S^VuON$WO?HehZ-?cZt>HoSst$+ACaP&Ohgp;)Z9D^X4{Xb8Uo;zHqHh5o=I)NiF8;}D?WXt%0xGsiD?&YzfnXgqWTdB{(9WDY4WN=QP_h--4soJk zx!=q0Vyw*Rh3Fk7t&(26uoCEgbGW~I)5NjUTev~1DLSdl?6iXaQfg>QU}(uRW!}FW z5cwzFo@vE1|8YQqzf(_k{I1aU>^Z4?#JHaKLkZ$RT+(Z3BQNDJu>SYbY{l|UdRqxi zP3&fN?UY>Zc2gx)wu4oR`5s?u0|7Lc85kE=9HxyQ0)OR|{Y!pFlK7_>tc!lY8v=Dd zztvfA*3tn zOtKcLA>dw1VPw%k0il`$U*Jrq1TOv+AQf!F@h^Ab`7L<2DuHVNW(S?Gd7O|O_W-s> zsJJNMnpy<*Iu9=|qu9$nfN@2{A<4bPx5#y^t*wo??Qk`l1NTv;Gomj6xMUNU7pP_f z?88T>RyIHqeG8OyDE2WJbbcY$3%Gtm0ZlRsY0N-rk_GoW6a+!=%nfc=zrF=!;UbLk zGI`8kg)?5JMrQ#qc@gHo_t1;mteydQRL*0YIeP+eS|-3qgGCzxsI9S+qM4Z>|BL)Z za-BLQ%H#KI^e$5HJE8OhE{qu7VTJwZBUrX+fKEV6slAn)LCPuFHOBu#g8)L=c`Z)2 zD;^*mf4LT)c~o$TpgUk@b2SUQ5$1o3ptQmuo*Wdi{Ll-i{j-p0JyGeL@axY7 zyxrTz8QQnHlO!lv)#{9i$E{A%433SSb-- z0|5XS^p!i@^oEoffRk3)AgV7u5cc5j$l-{)0v5;3cCJD+LvZ?n8h^w)ovc&na02aZ z26pa1ppyaC9@XE#co`K|Pt`OBalxR{i3O2M;iXjIjfV#t8zzVNLDDwN+$~H8sn2<_ zmjzWS$UcwCH>5p4EGwFQhxhBXmv^o&(aKCs@P&&2=zt(Fa1i9B=<7ZK@a+Gfvljs=k@oJQT*gR_&P7Zk z3}`Wy_W;EmGjA=sC+BnaB9-ctNN=H@oM=;%7Qw95L8GeU3e;Xe3Ib@F`;Kn*D}lH{ zOC3j-(q%8l^Dq|+8SU|C@bIBL>{nwMA6wUo|lTR z0XzOg%-mW`$9(_i{T<0&**kaK@BR{vPdxn5d3M&hvG3a43$wC+e5bU76lcG!<5ITq z`M=-tVgm5tC0gfp=&{afEn+zZ{Y^h0*NV0j{U9*vCh=GC@ydi*u#3;u(L-2`#Np@OhL&IBqzkkdrhrT=g{Pl;h~kq|7H{mk)?!T9*;g%a|}ByFUFJkHLBiZ3b9?!Rv=dXz5m zzn?-^_wP@~GV^jzeCD&r-zn>@5rhYU`-XU}4?Le*zyXpRJp$etyJ0bU3mK%D$f+>P zBs#e+CqWT}nw}L&RHy9h#H;_khVyj*2Y}KO{*UftdL5uVtMmSZqpocUZf#{MLaOTO zN>E8_rOM+?H25)NP`tIL%&;k(8_zpz%8Z~cmnZ~8r8VSnp^gn9OJ`&mP*OpeprYc` z^^cD~gY;`Oq6pKLCHs>0D2fs2Tu>mbPF>)GhwKa>o&puUiY%HpTIlbYLu`hMx_Wu> zd2%Y^JLecA{`b~xs^Kv}Q#OG7oKgy4ubiBn1K3m(q*)FZ`~Jq|0a3v}EbIiQUfyHa zaL)oF($*8Wb8z8Fh318mt7|YEjVLx2ww^DNO8;U%W|o$SAjb&ZKj53{vC%p4h6Jni z_W!-T^B4c^x_!1&u<54`+4A(mwd)BrkC8cS_K@E3)Nl%#C+HG_1{?q|E7q~{uSddd z3RQzB*GQ2iQWwo3^zAp?YEiUO?XNLAq;P!`)Y8&I;2`+0O_;8+gwukMnD_&C^IHy)l>eSCJXHMr4RwK)U_F%kBFPtU;ONc@jZyS5o#k5Wl<1IpCF<}L zI(Ce3UAv4yu?L9Ah-639B4Ge7gAvrBesQi|)f4_$IGXSDcu;Z$q7Z^&Qh&5(2ihcX zSyFGe#bEpa;QRsR=gfR|s8FOfymhy|lAvfqOh$uJYppof_Hd=QZ3C+vu^fTDy_$ga zAw~LMWaRZLe}evd6sc5u1_@gMcY&H^*#wv1rvOB=gc-9+_jRTz@1NI8wszO2TR|p8G!c*!I}^o7;ha_g zSL)R`?kDIa0OpXo48flZbFsNGxcV=+28msz_0U$a70WMs=}3sO<729EFYeR%*{H~4 z5<-Rg8yvtL0vIGblfj02Q0C!gXg`5!*Ko1`bS4zL1Cd!pP_d<_r<;Fye+%su^#@zm zvWAO{t6X4)2Oml-XLw^UjUW%96|bZf$DxCdnORu4Q8G5=hgM1~D2 z&LbGEL=k#%5S@ezIb8;lCy$`c#KA}({kAN*SY`&0AS6-41p$I_ijxo&W)8MTbSHsx zc>aN(pC78*0P`E^eQq%!YPN0YcJQaZ2zXuW`oZFyq*!2ywEcLgxcS* zk=&^^0az;92{NuymE_zNvK`K&#hAl`;5HKan3(K3oQx9xopql*GO;6nT>@q(M1e;_ zSY2gJP2zl2P0cyzMnRF!f3xBuC+F4fhp(2}HsKU0*}{-fQ{U15Bfp+Q9%)Kf=$%Ib zMaBI2rB6oe>IK0boF*8}B;ylHpT-WdsWCHXVhk48CdirC7=)XJmkt4t6CAU&bI^4G z(ixcG6D$trCw7rXV^Fpx5IkFch`?(b=91a?4M7h<(}4w^Ec zXVCTx(*1i0F$Q^6*4p_-B0yLE3fnuoS-m91>C>pA03S?nHKX75PxyYy^E;{;nJ_wKDiP|+^c=HLCQ@Uq1xQ-#8qsz5!Tp2`tGKadI6)+eHL`#NC3 zef0z0X!R_YWsIzV^a$D-&(gGIxn?1tz1NjBTvV~}8Tup^14m<mEOt$unw@y_sYUK z@zrR+EnOR?d1`8CWlCx{2;UX?1xaaZP z+S=UmJ3QbxgfeXp)3nJg0%4UEsxj1smd>aY7dE?q_7!0EJcn`dDW_ zbQ2=rs{sO%0Iuqt5B9^qf3wMqegn&(*l@uxcJuHvuX6YEC7+_5*HThajGQ;3fAGc# zWlshShWbbT+fk@s>})3?qBqF^Px0U+jbz8!e0AnU?2lMEDv=uJE^@|5sQ`f=8WFbj z@$vD7RX^f-Z0YUhPtD6c))m0&i{HviEeyKqfmthpg-@D^`i zWb8Rb!5M>>FdsK!TW#~r5(<{rzG|VrIB1h|)#KP?V>J8gZWjzj5N83*UM1$XkhMmv zPHj1Rg{0kx9<Glmz&Yp6`;^M*yZG9A<$R-r+m z^o%J2i6pSv5JxKZYKCG}I7IhA^~;VnEEK^;2J}zpX<0&SUjh@?yGwYuxCXCXrrE$T zf+WPo0G#>IK{a572czZ)6r*qh?Ad03O5`gn3TjtDyiLLKX(Q8?(yZ|dsua4MFPeW5 z2g@Ic&3P7;9}cv(!5UX~p=asD}I3 z{1<%NOzvztOw+PkF0*v#CJ_AOR-;mA#S1v4=UYbd+>F>TA?jxGgJpF`w;Z`lYsFFI z5rTt5QveytGUp48s~*7)kqTNHNo@}a{Fs_A27EB5(N>43D*7lRHTCrM%|J7!? zn%3%h*yaxyI zlw|t9%?kPZL{U+tsgu&!?!F(h5g1&<(#lfy)$Qnn;k3Vrge5)F9mPLT?+O~W-p7Z# zTZA>xa!Z2^8s!+0F*7II?I32|6N>ZiZ3ZqtGDr{iJ=lye{^(*}l?>nYEc~j{YslL~ zeLy!4t|uaiE<1n_ZUx!>zL8H298l%DL6 z$?RN;qBAJX9GRD%JXz2%z5w`7P^_|bb12&5`S`LzgMvPLYJ`P{<6^+!nfQXhAN*xw z-1j(jvS~>gi$od^4KT6-fmp98C}uN$eg6IwPARQ7(UklT_Nn(a?%g7^#?z{$Ai>HH zmvP@Om(9aMf!>gMLyj6Cr4vv?M+u6Mi0fCz+xf%3VTyAD-$ki#JAIPn;k&n?(_QOghd zFy_Nk6NA9#KY-WiCGZf@h6Gi5;=HnwQkvFVCB2a+Y-#czgg*#K;qe~C$fou;e_MxQx8QkCf%_c%e0hq#iJ}fj zd_(IAFb?_!LuljuQt5n-^G0qEv`6XIYk!^sSuzymBf{?=g@s*<+s_8*qY)e-uoWAA z^MSrK(y5?wG($QgT***;16Vs_>AQ;1D~>7zXpo4YwiXl{hTVw%3Thdu))rJv9<#^4 z9CpGLZ$}zIwtw%9E^kjVpCx=x9dfTbs;qOySA!>@-S{4msF=phxmHxOKslugwXwgm zZ4O*du=yizlfbUeO?iBJ!Z<@jTu;04%UbJB9tduVEs5KAeORJvs;nYvHPkWu#sna z6FKsHJL0wCQp{-z#jwNI!)2Db}%tq9M?Qd0=TP#cdcGt>n-kZH4=s z-KNQbAtZ`I`cSHNPH|nnjY-cBd%IWi74T#0{k;FSLC&xv|L2?4dQ+%yPBA=`^!ihG zjL~SQmEfJ{`1KQxn|IUA9LJkuX%pJ)1w1S>M=Si|@>-Wdg`Be~lffRTv8-Jk$3y4^;h5)%Zw zlqKX}8LEH3!TyQ6tEj-)%ziOy{dN z(&Xniv*~wv-lo)7YcM&_)mu?Q5!ID)ZvF%umKZ_qx$T$Ry1oVGm^UtRrlte!@AhY0 zEY4%*baUI9PUoG8WlKM}M-mg*(ym%?IqQSfvA(~@CE=^L89|4wyd)J-n5BXVijYLF zHY=-FOOB9QPCH0A!NkfjT5{=+8=~KOhYDacs+yu;bo5w|=WD6mB`G||@zyq(3j`Ye zPCU^|7jmv2ld7ibd`mXj?GAw-9-py|7rDbIiBtA>TlAZUAK!Ow#>zjD8OK4I1Ap@O zy#+v_QKu5hNz#4!ZAC?pl|HEDT$uN++E2!2jHE>3@|L4V!4F);tva<1TFDOad5OEF^#(;r-0gVN62{!aMaUT5ANFFQ(y-QSB3 zYsd7%Oc|pS&n#X0!??Te|B9(l?Frue7_F1rFC!ton=5Ucl=nMUW$rCZXJg+e<0F912Z*+IYs60bpOn3wH@8y znIqwAXWrM=W6tutzuEoF^SiL|(VrF%h8J8h$%*(X!;yDPj8@~%4mvuGZ;7%gTr!U2 z2~EGI)c-RX`zb4#o|ohYLs(D}ZNlvmt6$Ajmic!R_D{`_D9*H})NZlwkYgG_r0x;< zPo(yK?wn4fEUtY^|C6Km({NG}gV{VSF}$nf#Oxi!s)?#1%Bt*r^`nBQZr%OqZv4*QrG7Di0eR>Q9O25<7}4 zcWU{JU{;obMqf8^g%ObH~7xlr>%q{UTZeG>tT}wXERn=t} z=W_vPL-D$!4?azQs>IRh(ua2yqFv0?htIqV`ZlwvQaDKM&%bX%QE)y5NMuh3Lp4og z2zRGLSc+CiPeqIkG2FzzF>wncTWuzlF6wc9-@${s&gCXs2w%v4p#n?BMFrQN*-85B z(tN9jo&75*wGZjq=>;?=9|!E^Z3_9HODaJ&)~CNqJs!6Ty}JK@8h(oi!xAma`A>2+ z592nEHNHz7?Iu*tvg4NS*ZaF2PjH7-Gwe|>Q5_NZBq(`m7ES2KVz*zAu)WWOXPxow z;fBGnAmye$7G}wc?p#rayZ_k9YNI;tb{WoNnMw!7lRHCoQ9(=v@VWNBy0MiVV~0@O zSiO|FHB40b&uD7T^2|C)UQ>Uv(N~$pja+PR&0e2xO~!iG6F(izoT?U%Q`;eD zR*f}yrd+{u{l{|)I==K$5Bi)P0_>|@bGk&;bbPN?&79XIC4D`Ko~B>cG@ClPD`To` z@qJy;{<+9fpGf7Jz|ks2429~d3$Wq5g7s8vBpC|yj?^g^dpbI$OS&DEjJnn|n;8R^ zl}KX6p2=P9@e(epO|Ch{!k0?XTu|yd!2_)ej;{k>=a2t?uUlp}e@1Ti5xv+Iyx_0h zHrAv(L3^D_o|i5qS}~h+yfvDqs5@0K-r6)jojjlZje5Oth7r@%)`~?yDx<>jET#G( z-P2MJkGz#s^iRaX;#ex3jqU{=i zwa(mZ{VlTC-EJ(*wzNnzO+DK{&iUYY8H$H_K?^Fl#_l~q=^DYy3j!WTU&!0?rDcr* zn?wVx+7nw-CuK$~9Hi>pG1`gsSVQwBy$8FALk=R}D*#+znU1UBA^Ip%(dpAYy#~?( zEuEC7m+@gs0(GEi#*iAnu$auT9%JrYPv(cW{BHjnb8xz>(d3be%^%(%yC>W$mh4)1 zM3a6&t$cN5B`DZTihI`72yV(7J}!oc%Z52}FSdhp_z2&XZUXhCKMwq!E% z*RNlNNDxTY6jcrp_9*gRlUe+?pW2G=PK9#!D?R8T%pfp4wim8{*F?wCz^(PIR59dT$ zmaGpiMRo4#toxnklEgfd$9`xf@AaGAsp*4_CHBs_h?Vagx+?D`PgJ)%bSWJyh3N-( zY`%yx_y4s)8udU`#n;g&m3c|`o6rqd-p65e%X!* zQ1QCOL?Rt9nj6=Yg->B-jP{tBe+Mzaq}~ec(K=m(*09!pY87umiJ0f>W#O-wVKeXRE{j(+c)1cbX^gb z)N@sUdQXLj?!>H-r2d%JOUfkoWa6^$|N3ECvp*r=;h^9V-o1O59l3jjvY-M7rIa27 zU7NUiuwf(LB-C;rU|0?H?qGVcd%+i#zo5l`1&1}5NfAQ^Sp=@W;53Yt0z0Pk!G;35 z70VMMLmjk8ltE}HF*%0L*ju<(qOJu>4TH|m%lRzm9srsj@;^@^85?dK5WvO=W(3*I zuW^inARFF-8u)MAEEq`(OfG}6(~-uglc@$*bW(mWD?&SQ+Ct`vC+8et>j%LMTP@;7 zxL#M^YWekQqJhwO4Da9!(SzfoQ9B{?hp!rQftpr9i`Jp*e){MI{Ac|td?TF;Cq0P# zZ}&FMRivt%AW?Q%yj`R8o)TAEo(V}WK0{aJU622(cuopu9bw;F#_c-4?Tzah)2$$|lu+e9%Vskb z*;@VRSVZa-5w_aWQ@jLcWNVJORA4Bq%tzDgq2Th5B`s@A3ZE$+N0+j^eqESn;}vZK zwTDEQJ^x{+@$Cz2fv0#c`aVj(Ca?X)K6E(Z!KH^;iq|7Tmxqbi?=b~$EjW9|W_9%y zjJqDa9j+s!5X2VZYMm8AJT`&?3>F zRBgNAZ=c@u3)+6lqJG&?ot@dVO?WzmCnR0n?A&Cbbjl`w50@UfP5hA{a})90e*Ou< z;MhIF+)5f1GAD`6(}Um3&t?o6qz|o`yf4_5gdV(vBD5?2r=`+>r;|_{%fss?Ar>E> zbYrX*=!&uJ9}Yj2)-4a`hk+&JW_k}ugwQrq247Y&Nz^*5SLk7)fj~eIMu9AWB4;38 z0Zfsw%_SSd)v5#`Wud){au}d9=#8SeA!`6V4Mt_fVt9Q>q#D#+5!-@c7G#(QArj36 zpa}VHur?yO8a|yYC>y%D7fnjRH-$vozulLvSHTdshHg~Q(WC;iY3EoUB%y%M6Iy?e z@!swSJ?s;5jRIL%l^>|{Bsar6ZzgxYn}Vyh_S7&(g_Yp-6wQbdgA&zs9UP8`5QQ_p z**45qL*7nxFt;n~Gk?Dq1Z6l|#h))U{#=;p9DA(^37lsq15;G^@k2@0d{q{yqjyRl z5>v%|xTBCgG2`Nm2dl<0aY`1d+P1f>Cxio;mIqd5tKCH@H%GdP1sR(t+-cs(>^>^k z`l;e5LLEXeq-jlvweu&xcP)w2?41jBUWLS;DvdX|BSBc&wzUl24<>&}>$r(~g0i1WieNXQP(J!W#~rP} z8s?arez+z(-J^7=+3&`%o|v3zI;@zD+Ro`s8YNVX5C0VT6YE8o<&Im3n)LdEi+*xm z&7zu9(x>LT<}{~dlJiZjEw%L~nUW@%@?|a6;iT~0An=+V(=?f~q5jPCc%X=L?RBW# z;FWB?b8*ChLKpY7}2+qSpL z!1@fV2H6-kTs*v4P|QU;K#-^b!2vQHVIZZWL3DZVv+o8U z1_ z32H5GR>R@1ainiR>;0P>3PA_hUdzQ~??+&)7&Q?~lg#M<;IX2Lu?aT6I+Ta)G*5?f zJrtWW!pip+1CPf>DEHPGsry*UHwSRB`A>*0bC$~6zMWMeB^0emNE|(>JcMebLj`t< zdvL~{SxY=3wYhQmY*kk)*;>X{X22+_KGquQ@Rv3N=l0(2($3~r-~GunO~$ZBq>|+mSPU*u+NKZFV-!EQT6a4Lj9+=Q z=@r#ak^1y92}aZG_vzfQ@o2poF>Rw-oY~pXN1ZeV2_!r7uU%eg+C=w%J5H(HXBNKe zC&tBxrE=m2=d_C014Rlmov4~%vJsQ_k*%Pj0c;c^V8{NN*Q_`T5+fg%L=5f*y$`7XGIZK)AXJ{fz z!^)Kpnxc5k%$C1cXY;Ph8?8#W>~NJxXiHp7S00%ucgT}85jVe_lBAcVL4CPI`okz& zz8l|SZjjvm?%iI7n^;8Hxs@L`3`X&60BF5_lRo3fgur!+C&eT9(V^_C=aMx%ZzXrw zKW^ODxPCO94fdt9v=!9J1|_cd{Vj!9b}ejEQ&a1{G(6C-t@Kd56x1-Nuiq>UvlJ*= z1@>trbp-f@5h7fuNYKZiI5RLh=j!Dn7vdeGEa)3RUWO6wV=xL4!UuunQdzy`!4k!w ztq9n>&t>Hd*@B1RQ(7lse11=L2*mSfKWneuuQ-A zdGtfG$U?D3X$IvmKnovCwNay`weWYcDk_A~$i`v8#ekg&3%OyyP|Qg{kCg)50@S4i zE5|AmPxinPO79_I5GRAZpu*_RclEiG4Zeqa2^kHb^EtV@w?bG>*Vtcs=TDV4QUsK4ONbrpUaIOL>rgm1G+2&q|C5xPW4C{ z*l_U`W}L^mIE>A;kI}cEX28Q3kCJk)ncGiK{)q79$QEu5PcAG!dP-yLLpHrm%I0?5 zCw_8b!0#T@w3FJ8D!R?WA6G1ThFTJz#b7VJ7-#WqOjgOC+;7|Lb_n=h?53G)bn#j4 z{o@QV$#`7)dz#6nq=X8*Wy7L7Ij&A-GAgG;m)nm?H~9`N{NDFniHYi8R$!~}&{uaE z7r^2@Ht;2A6Z&AjBgu7-A!L6{dgIIQr-ttZc=#DM1CQjdF0!$??%_XStPkpLz}ud} z>itt8YZ1y9{rjs6+qUVmPu@X~&q{L-TqD`@#6^kl%0KMGwwiV{<(k zPMs#udEY+#r(N)+RTk89DnEXUe3~QQMKT^}l(fX@jjsElThHKVggbbWi@ne@tfB1iAuOB35r+VMMT4C;S^;Wt{) zvr~@-?G-d&(1QOg!LbHiCmAv(rbuY*oCRwLnleKj888TwFj#x<)k9<$q|5Q54i_}L zrtO*0zkr9@)T>SW!#VJSLzyTJrVlj#f&3icPgVq)cuY)jeS&G3?TuJ2{ZnXs2+BGP z4bD|C#RNihFdUbC=}(-ObKr=?Kr*sZ-2Nn@S%cM)1X^EHQ&R@D9s&pp0=+s`wr_3K zwCKi-+`@>CjJ7W)n1O}!xt06#onr8kKrHqd=p1~2*yR*p#E_E*K*YH@Y#3;u0Atzc z47(7zYF&+GXBn9sWO@z_(;pBG#%=y85n?lOF|Y>)E1H5&47_Kjz|WTk1|o6a1M&X; ze*3tH$B!G~&pW#_c#zQ?4r_j&y$1(S;p%ljiw+AjF~Obe;eP)vXu zGF^tUaL&99n3KWD@)fFQ{#9K&@_X+-_RLKj!XJc#izKlQ)}1n&Ck~EG3E}uvX>K2H zP|`OoSby;b8?=#>qjmS}0_M(IBRcoRYeUYzHRV?#zJ0$dr(Y*DQSM`&UbE(d>xeMM ze`Sp}>cJC`BMU5b+NBFh{!c|49hp+gjPia}AkURvy4UEPTf?pNelfQmW0K*9%pj{r ze5oqcI2=;6Aa>8VEMA7=F>G$3jWR~BfYj%SEmAZT`yHxY)?c~77OE}>N}5J_S6Y{- zm1wj1PJw3fIQHrrjZ~QKR9&2Pv6m3e8fM98vr};prqs0QD!itmDP4QpJ~I7^Me83G zzT<|ZCbWac$qF{d_)b9|fe%+lC$|*|6+Zw0v z$Ky&tR~j799yW|5Q&il(8m@1w910BMslQNF(@loIJ$#BMt6%qAK!;F;$a&)+Cv4Pp zS9m_~+J@;M14xv#ep$KoGrwc2h7JDs^xp7gTITXo@NwRMZec5yUP==@h_CP_uVQIt zUE$8Tqa>!Q%_jQ9)P*`wOz3>)h$-3x4r93uDKOyZl$n8v>z6`zPr;)~gz;D{8`DcXo9(0t)CmRXw1K%)#ek z*Dqkeq(Czi)d8qUD52s8nUmgNZbV}%zJKP={HZNqqBeMDtV?a%JZU|k6XUlknoNR2^rDIn&kT)p~1=J+5G)#5166U=T9bRXuk|BFRL zWT+tJwa`T6E$^#!_J$pObeW|P?3)VPfQ#Y4pN`M5;43KOe&z{QZ8U0=?7PR|?c<Qk$f^<5g}&1Hr?=y;`~)K9_BU#AYjh-!j7`67!#~SZIPo3 z8$G0rGGPH%#>JGnJbq(k*kU~^0G-gP@H(me^tbNl&P*4nz0;b-mz-P}qpYpyrzWdi_U*>cA50sm*zB8%`=W>E(TED{meF#@@|T# znzKjY{(-XfZdWkte7(Sz_wJ&`k-ULTCEU7V&x|G=U<-;iJ>TiME=#oNUs zJXnLCOaByl!)N8j9P8v?;B}WRBkP0DyQ1mrKkH&)f-PE62|jcSSi!_LJElrT=HjIU z4Iv$~%x9N$c!saxC-Ki0khPR~zP->Ur;kOH6JaWH^je3uAzyskKlmIQUjOp1;Dtdx zP2=NMx2g&6vpmU2hrU%F@^5%TOWMWN^~G^=Ow1z~4ku(mwyoe3XTE}yVPi|s*K7&G zvB9D1516!6@E~LSfylrJGpmA$zqX#q9QlLy1%js?-SXh^9c6SPId!T9%2G6EjRKva zFOS9w#>ySSfPVucg8=`qfO8o<@{hE0I$e^XxW9wN0hlM`6olry%NPV7pJKrUD}fbt z5QJt60mLX^AS+po&kd{zFvNcn-0<;%5g5%p0|CT_0>Kgu8uh_ME7Ck+xEQ`D5>nCB z@8W|?-Z1Tk@J8Umh=BSWO+&@7jfz+d(5jtNH!(Km-cqx>m6@LY#Y9wMwAk{X6sE7} zVdfP)s8=%Q;eiSTbCTuOw>FeJ3O;1#qXSR$%)`u5@6rQ!$mPNBl%yvHka1th3|rGk zr%wA%W)oWG(-kdR5wydO*yapP?)#-XKG?EfE^0Uu6Je!GdQ*6ZGZ0Kj-F-Mtpzm*{ zjLkER>y2}NO6hrLrk&Jr{N{^jWt$X7jTA?|lm{l}lTIt+BTF{M?t+`tPIrWy_DJX0 zRdqDaZ)OTa9Z(LOR9K8j-F-)pZG1MC*S{Z6-s(W#^IXe9wT0WXeJOQH&d`YuWw^q} zhe#vZdXP#>YkUp1cD@muQy5GHK}_;KuFcTpkCunBbJvHnp1$;T8^1)yx@S3e-A2{x zpzwnMtJLV7*CKa)g9X>vgAz$oH09W1{nK%oj=9cHJ%VQ=adjv@Zz0zFt*X(aFPW`= zrGw`g3XWF)3XV4@H@TDfuRL)RJq&%UyT zglDDjwA*pAH41x#Kff|9=b*UEU!Y?UAl+|o{h*u+5MghrzeKDNn|fhhOjP` zuWCLPba&MGKF=-3`%nW+Jr3vqf?D#VeK`{g zhGyb0MWr^u&Te=V=77aE6xqKP%6cA>rMEth%ync0HP^ZFNa_TB#6?QlHNNp<|6u>U z!D=->W)e8zi4!f81jk32k_#_ey}`khWRST87U+?tn@$qM`;8cUBf`8o;V18Y$MUEk z-sbDqC$V~#aQMmuD@5B%)D=t4uN@VdvuiQ##SG=koW13MYgZPjX*Nc*)9m0P09~+{ z57MFa#wEtKu|DZ}OtufLr*W`@$-$UF`EMeLnE16i8umrQp^qcV7!0>&ns6j$z5p2l zIof0nx1TdO9Y9p{%=~;X)Ult(p&(^Fd$y7F_Z%Eoa3qc1;2R)pfa)1N(TL=LJb>T_ zmISF68WfLUD96Ima?=k|9o$w2#um#VKqfK&n&)PkbL>^~l=+CY#zxr^|2eSjU@$;x zDP_02TH~?F3I?u|Vq#*#HvL!LV2i?Y33exCa6h6`6%Ot;tX61$SzX%dNl+Vj0TZD{ z-x9DRGW}lz8vy@(zF~u}l<2*wa4@d3QzlG*-UfvW00bc{3L!Hu^i8oa;Nt;k2x_4r zI~m;bU!#ZAk{^?q_2u+E9P-y)S$m!NmSc3A74tY}mkK;RUr=X{&|4q95}5hDU4FsU zi_Yvq3~nX0Q2l`|4hL!Y_!zL}dS6<}mi=lxp2FoU@AiEEJ(;vK}>Gkx>KGzQ~9Q2r(5S`h+F!Rl*F{MI0Hb(^BaO^Wjoux*if%nE+ zSPnmN5{j7-iZRnnT%|j9-3f5|8Q{c&g^Ey#{m|O{d>bLG#u+U-o1sfkS)bVU-I<}5 z;MBlAb;0YHn!j>i?$2kJeCt2ILmsHKxjQW2l=sAM-QtdF6OK-BcqdWRz9jb6RmAj) z?Sz2!Gw$-wLU<>so)t8~C*)73SH`D%5Vd^?djDU$6Vz7nq+$biSavRDJrkx_%Q&k$ zrq;o(&Gh-4Nmu!eYS~H;!;|%G-#L%)gNEs~+XGt*L!-2rjl#1_cT`y^+Yg9D?|l0j z9GkKmZFCQVYym&ld!o#1x~rT`us{FQvc_NEf+w!T`KcBc)>2ZWljwm+6E{BrMm{y9 zz?J}Md_JQWJdZpHZZm$lzRXV$MHvOEBJ>59OA=G6%bEM@M@^;jn55z(BQ>SS%_PaziuV*NTlrOEL+81qjmaMs?T=6whQZpS1}H!)XihLJf&&hENEk=Z zj}Vd*r3ARRXJP09<-5W?()swv2MtX@Dfl8#yBkdUROsr8&Mi3bK&B0unK2&(4tB=X z7xI9kWQZ+s7f#97ijpVH8oDf$>B_J^Q56UaziJAEDRirjg5w)viNB5R!NHgr?))@; zWu)+?+O_>{;f&1U;mqqI8HS1d!fyfZb_P6R{)L5wlEs6it58Xz!MET0`yC@XV2DFV zz@$sA#4uK^2sd0!Fxi4?-PF^g;mdhxZo5#U;^R~``TPl;nf0mQR-z9y+cl*pe)ILy zEVgqmz8tC!D!(oImCy2h-#doh2kgBb{4;uARlW4E#BVNd3zv6{6&>2*Y*t_fDDU4r z7f~p<6A*VUU<__M{XcB(wJV%;E8LVcok-ZJe!iCBuhohviy$~nNj#T*{bA-$mrY#T z25WBl7eIg!3vDEdH(=@g1cj%mivnOw{ zQ)5&KWoh2jo=x%a82zENH-O7Fchch)|I~vsaRKHx!|lFHByUCTh=gBXk6&(K$zsxw zeFRS*Tpwfd%Zg`eIqv7LiVe`Wi$cFq<(bSGdpv^@BX?o0svV}w^5E)o`SKlkH8`5w z>RigYJtc{+d;D&KX39_U_m?lo-}k}^YNyh)Nj`Yz7X0-{5ftp=Jl>pu=_W4CjlQy#+k(I_3BquafsOijy1*w7{w_c!!K}1N{2+S%3B@=Dm1beQ>#!SGZs~yl<(4LIYF(9~FkCk#@EdBdn zw88yQ>EVrBT_TLrLMLr3=Z$vmbug-m926ilJT4XCY%J5JTog7ET8 zQcwu8fY1yNT-&xD94a+dieAM07_T4&D1=}K$_9c#C^9m0 znxW0~$F>7b6lXxU+E3MvvfhIl4w*y|wg)7pwxt-hRQT^!4=na(m}cm!ala>#52kL6 zKe)5gKxMwOKUr}v+3@mX;@D1Xzl**C_&fG^Rf6;x*2H>ODWV>b!h#YvS^CJ;B0E_$ zlrYlKe0p(0xcV zz9i$1e;Dv(va#8uG2+}a94)u4lckLIokaJJ)o>!0GM3Luc%(T~#z|9Yd*yDL`Ea7< zbvr@4y(1T;tt?G3Ns?8$l9~MUbeFQrD~H7|!T}P|LzIDS4q2ml&-VH1gibbQ?u#i~ z^pK3yF_aE|^fB?vE@XUP<=}`n!%)@66H6zXDr$N$0oPz>iYlt#n=+_|kHk?!2|@=R z?QMHi-lKgXrn7J-YG^<4iz78WUUHj@bd=8u=#C>Ah@0QO(`;cZ<5J;{M0hRfV=LpBGpagQ-onBHhg`b^&JNlRnD`NO3D|1r0vt`OPiTA$mwhk zh4`4I*BK-VoVcEpUhsQXf{)YF z)2yQ11ezeanbeHjE7b=- z^(Lf{SisC68kVh%oP*1aqv&f+a}pO=?59Al`x`Y6^%A6TNqawrcJ(P=UtjCtA8lSV zu+7Wt{}6&J%%57A3nvC<5_~^B1296L0%J44ZeT%V2DqqoXuM|z%pRN{FhWr2eM(YN zvKRd22QXH=Ra)im%=Qg1d*DmIkc@K^wO*os!S`XYJeWgs=>a?Z)m_)8W8sMX<@vfR zd=+KZL!RX$B*Eq}f|(vkEuF*P((}MTjE!~zFubD)6%ZqD!H5x!q{sWoeV}m2L%YO& z;jX)1Px?!%gzxt}PX1mwExY276sNm#(wg~re}i~qU1??f^>4h&(pC=z!ROIxqE&w| zuGM$_EzzcG@^9Gb$<8XB@(|79Z1pC{&E)#*7RACu_KE$cE(wJe@+=kUXi_e?Y*H0 z2Tz|Hlm_2!uU2x8C_KFM_=0K+d5I=#zM~906Q6%#jP=Bt20Dc0EPh*9Mrq}U4)1PG z7w0_+;ozArqy6-`Lu9Ymbz-azql%^Q#Fl#e4y`8d01c6p9-VTLQM8%daMI^rA|{`$ zcJ|MnZL_1j6YsRCjxEpoDwc9I zrw&QM*oIc|E58w=yc@5gSvU1~hbQDTuFvPG*W+1j5?3{RlErESvxWWSsH@?E&|es( zXFiBV2F16Dkctk|;KuddvLrJDEh6jU_bWff<@c{Y}&+ylWFfUOA z2BUB$DY`86*YW@_c@G2OfUP_MXbd+3Up1^MQ0Qi1zEkv?hKdScT^h^wjQ;WAx*v_;Z9+Cx0Ak>R!w5#r^<>}n zeV)I*B94Z9(U@Eo`woD5;JI@K^!8wZ$vDS~7~I>}#oao^i5KDWt)imR2o)0+21o?_ zmy_WsRf1$;ayq)FFuMX|NOBnAJ%4a;umR86tJPg7g<9Y=Led`^3)X~T^!MgiL1bG; z*U_^A3)ccB;t}))gEB{{U3WG&0oh~S+sG;dW4ciC$O9??9dj_pg+_|dPXmC>dwRcM zAY2v}KO$!#*#YJ@(ci(BMxGR9b<1c`_%B?YJ^<0hsy()-y5(IR40Mvw(b+~%rO4s5 za{bQp@DTaw9SA&}nW~=Kx7Fgd9_BNsV&_WT^PUtZK9HIg-|Bml*hNgGSy1H6*m5;L zV5y_} z)_vbq}W zurRQW1AK1*7zR#~_89{_3%~k1TRN;!0Dgyg1FHch(7^|Uk7{L@eGh_C811}ra&k_t zW%tIap?TCeJ&KJflxc}?g5XjmmkxACs9{t*}>srgL9kwyw4DjEa=#FGQH z6RtlLEOW4_A*aP=+r?xIhECL;onZ3~HB>y%nH^7?0x~I(P*VoRN}$UFjGPH>hyHX` zO!zrvA$$gqbR!VF9R37hAX56Q5q0qv^m<%02if@f$p8=37|QJX7>45H2?G#d0{DuE z`e0}{u--bb|2Y!KRu(V{j`~f|)_Nsu^9Z;QvsK^Ul(5r-?gnb{X(VR>DG@$2xChehzsm7!@{^$QW}3!{-=$M_9ZTE6_u|EG zufP~fa{pvBaTmNOsK%1GwJXm~QJ}!7H&x%LekxEt*!cL5vy%PMS9sB$Fz z?aD2O!ocMC;9u*0nqPvc3jy+%A^kSy!H;3hlvwP~j5$Bo@ zd79gu2`qn;y+BR&N?xW;d&J?K^Um=)d7=V`H1GlW|VRt|ly`pk* zf@mo}^&_KY4@}I)6188Nm>2JollrO^$;o!Re|7f7{mkz{?QbuOOSdkyoa1~wcF}M9 z+}AEXmcxB2rb6#D-#1ln%ggm?X{)&|yeN!o`EP~GD^$A&?{K4fYsLx=zF?Qaqf4PRWYmh4^T8PhR-TRvmTMNf2k1#jNodHu3>*79TH*`1xmX=l~? zR~{oSKmL82;pIJP(=S(MEUm5@Dwaqw|0ZtrWWDfWF)gt%rO;{0@Ahru zEVYs6efj&P?nx)1FSEYx`Ixm9&pRL>WP|d{ST5Y6?}D??q6h@Iz#z?LGd?bqCEi4O zW1vE5s#LqGul7SSg-ltwO&e^Tf3@j?pO^}bbd4-BwBmJn03|-a+g4$yVG_wEQGs4C zeBiD+=Uq0ZE&3+uHb<{>53v2UEG>CK{F$ne8>$Ie1hxBing<5kt1Mjl7U=s#Vx8QJ zhh&u%{E%$lO!HQUwgvh`IOE*iirQCm#~b$vi>Vx?61sF;eE1;QB=)9@{`t{i7(3!cqePXjh_Nq(g#%)hN^l?y;TC#& zh^Y@7hOo9nlOPKlLoyS&=7ze)!)D|P6RnG&E3tGB8)^OBVx?4+Rk;vUlFxmmcSllT z74Z^=wL~xp2na~bKJ(lE{(b2=)3)u}l@H+%3xin>!v~_D(|OVHIcpw6I9+Z9OhK8q z5B4vgCbVJpgbfE8B-s$9a#00sc7+aW5`0{C-CTS2#S2Arjgmdke7GCqQ{@4*p>4hQ z32aB`i|vmUyNaqOY@r}mxC|enrlx~T1AFWyF;6DPVRu#N*7+OkT|b<*4J~Hz*naOy z^?b=?|Fvs#>gj=wAa@zI`L@oG@lqo-&gw5XoPpM04U+@)#!pxszO1iT)M?v2@vd5L z|2zmS71&%24G%AbdIY{LFn#uZX;RpAgikOg&PHy;H~w9F|DI(bd*U9~tFA$sXlwoP z1$=Gb#G}`=RTd^rnEs@y?k=$Qv!9f-y!hhwdsKECaLVYG{|zbI zF$ae{ESvB&jbeSfzFuaaU*hDWNVN>F`cR4VpXkrcwf0;6^^oFBi8HS912UX($W%AF zM#gGHPGMqC+WD46$a8sA^UKwvC4u?*QMBc2s%pmBYQE-cbwdlfzJ0s4+;v9r0lRPq zyzu(}{=)V9tHk=BnG@Gz&y}>I zn7)}lJI&h9{LFvecGG|M<)6D4l=nY)4of%|s<`QZ1^X{woZ^4`!~OF>{y%x3pqKw% z4u5`YlK$Ie?q6?RNS}yYuVAnfg$7yiA#m(Eh>by^%K84B{ zuk%52{D0u6*0k5}z6iOlla~~=IV*c{f>V<)%v|fz_x&sx-hR4J&0n)<@nSa8)t@_m z9_{6Nu#Y1I4t!3TD4ZKrEX zdo41@20}iM;|7Rs70l)9G9634t#TXdR$FUbHH};={&v;4apT4T8@}fCSVzk{S~TrU z6a($KMxBGNe<)vs>4C2D>g)@9#_nsGV06GF2razvn(zsNH=HjH6j2!&ezM+1RLL4k z6y@VaCaUtsP_5y|kAM#!bjcJQO)Y5>LK3qY-2YKT%kVLVFJ6nq*#w4dVy5W+nEs z;D5-x)ZW?1;9wF&Yq!q~_Z)(k(*vEP8IYKeQy*C|;3{Fo@W{E>QeyUNtE5x5LD9Rz zWX~?ybrTz&Z~$rGupXMu(CZ*rlX=P0G1LH=*DoWzI!8YehuJVtk;)c)rP##^E z>$aeoZ=km@^rDur#-5;i6|E}BoIU;g{4gQlQP^)37&Wt89Phk4^a^0BC+VEl?J-o~ zd-|neiO!;goBirY4MGMdV}|Et*$Qm|AM*-n*_*#$g(tw>885rFEnlC=I5)zkz*fbI>bIX&pBnFztHHvE=mDfl*_Rwr_A?zgnQ4T z^qeHS&WK2?9vEz~UE#eyXJ+Y!s43?+2d^;DhVPjV40%$X;}_^57B16Cw7K>f^;s4r z6*V=3j`BXQRWdSFakjV5Nl8lXWVol!TrgXi!foirlZlK7?bn5qlUCT*!4iGqDLD7s z1;=q&RV}!h-$Ko4!yV+Uvl&U&t$0VR!~=TpS@FAjS4_?o78I-=7%~x;pN%2PHoZ5a zbNyW9?I+_&<5NsLX*uk(CL7hbq=I$$I1?I5UELe1Mu5yL3@Aruy~Qq5kT&pFCmE~3 z%pS_CH$C`;XZCDQ(4`GPm-a5`bO;+bbL5@6w!L!0H zpF?sdAQR;T{nYqtz-%=kH*W&iv+s;l5W?K^^vtsT+qZLKz(F92%2={Dn$@p(wnVxw z04SwiwWbqpT~aPX5m!^XkItoQrzN|{$`k=(hh}U0ss>%pa9k`CUpCV)^Tc7?`$7w! zj*uG1kX4_wR$@b~x2sb^-3_28MN&gG4}O1SUJu{0DbSd)efHSubWb6%CiRpR)_@)jl(v6&(`ACsME?N3d4xL39 zZ-kYweJDO#o{97aSHsZXT`;HsAg$dW!KsL~RWr2yQrOvFo9y}FoiTBW#@+Ah*K&3e zTXqbjF*^fcO4`k?=H|uM7e2l&U%me|OpG+j6;Sp^YpDoivxvoUV;256sA*mo8u)_T zuZeNMtqq35Yns*qAsYhwze4gx`tpZDTY2#zhiCogXj3|fT2b)3fyiAL)b z$yHHP7s}jC#5GEoHHQDbT>Fc-+$I0I+X7|g%yn=5`7B=5&k@B795__IS7zw ziU^Ihg7=)yrAztXcp~ff1AW=E)6(r|2}(%@sENa6%sD2RyEDk%faeEZyT(6;x5lj8 zF(mnbQ%OZiiHS4~rqj)^Uu@(yY(X+W%VFRQaio#(8tev}PFyF=UjvUb1&_H#3O4{= zfyAJayZ%AF2;Ad59LEd;kjv8W!SPFH=bQEiKBmTjVVB%13NuIInwF!_z^XhVAAUI?kwnz5jOk;D409Vd=&~5cJQ*7T zgp@$!JRRU-5Df%5y>frF7x#@L{J=H8=D#ev`7aUa7bDWsb-Ow9kGf2gr$x$uKQE)!HSob}>+hI!v8iwcQf0Cb176_-7Qo^v7N14>Hw zO(aHIFe2~=wZ4XS6SGpy$i%u~)ixJ6MjF{Bh0K7!=Y~tN4QS00e zfThonn9hC#7skfC!{GpH`=&nXmtnYJ!nnS+9PW^wM1MWh5tzAP)8z{II!PWla6kow zwwF;-xz{yKH!MSXq!VG$0P+3#YupVj&GvS6W%i(XXLl}}PsLFX=v#2QYD6l)n(q~r zjxdFUtYx8uH*(Vn;`9OHFomrl;{h$+A{ZS3{N2}2UB=vo9~0G-n6DL^@I6_VJ_(?g zw>W-Zy?RB1If#`ZDQ;Wak+g9z_@;mR(^4$~(<@9sj$)2M4h;8j;`Tqcm2UQ>NDHG( z`Zgv2Ob^*^p&0`z76xl9I5Q5upudOfF7<9wKM<^dNhApD^ctSDqzXi9PjNLR4RI8D zXo4cb31aWU*xrw#r*9~v0Lb!ib2BGTo+OVg8ajSrgAgcI68@DmPyc^lsFOeD!$a2_ zV9RX$_bSV{-A$sF1Y7Fkiii^^cjIr*0P41PYX0Wnr#+wXptEObyLq8^h5 zcV!-|+u2}1`4MkR&R1!=6}4ra3lWdUc9Ta+ApzoL39Ox0;FBzYzaAVh@6tr8p&!q> z;uS33AZg64x|#%}5f!hhs!C0T>L8(a(FKNVU%dVL$`!eJd8!8w-mSF741rf56G$}T zxiy*xu#X?JMfkwW%LJyF<#WD6euAgI42Djsm}`{}Le~DFo)+^cNllN==1DlH|6@h| z%bWXO$*|ncpn5)}@BjXa%Juv2>GA(rHubNUq~HHX`QiUQPyc_pkT5Z{HD3{HJ|v8{ zMl{B+WpTlLSHu>sUAwk=;E!)opCq;^$S3QGR8}(pw|?^#F^mTLn=nIyw>vi*#;P^{ zr=Q=3reoo#C(l@`KAU!ZU0q!Rr1dlh48&c)+FJjYCUkc0^`Q<4R zvkPMCPvLwBwB&^?f$qO3AdVXw9G9KgU>H!lbs~3GFif^%1{%xp1 zhuqQU-BU1>33R0~*-<-nZDBu47N?jKfSA95R3l<=5kvqY?LXa)lb@Xn%$`yD2zO{2 zw1Mg$)k!H=q$)T@qV9Qq>Uc&6wE>}A1AdKABBmq0#+h&tFa8Dfs9a={LJCvyHhGxY zLQkNty?SOuLmb!{Y5ZE~UZ`x-cR9VZ~iOH2}*j5Kr%UgS40tBzcGQ6bH5bfU!$U%ve81*0yY zGr1uNfDX#06rBMQPgHwS_j|P`SJNBCD{MIT=wgF_pV9F~zMY z!Jxva1_)0=XgOE#V5}aD2ynbvTGC8~Ati@A{Cuh2aL`f1DW!i}^haH!5J<$WwN^_k9mK^w(WqI*xG& z*^k`be10;67K-I4DTv0##>8AiZIDF??rl;|sb?n$lR{5d`9%Z$K0bvt69WT+fD_^| zt^`6?J3Bjjm(bM~KXGenvm)`SSQ)G*wo6+V#p>MO!$O)5aJ1bhA?w>fqQ3FT(fZGl zozZ~`$}k>yWK#q5HAOB9fPugE^r=%OuuG?M5aQ^eldcKx4y^_8$ltnnd+;Oc(}20S z2~!|=3W{x9)s0gWWE)kLkwWO%jN%GNRuaDw>Fp#2EA?Q?;)x}!^*%Xia56ohrkxY& z-NK7N>8uRtV1oo^&ptTl2AlV3B$~CZDitk%!v7jDPqGJl#Wv_0(gqw6v{Gn-3s?vi z23E7VAnudFACRh7POir08J=>jVSQ{r?54>R5i*eT7(j3v{02$&4;UJDel7r(I0lg- zvQvJf3(18qTt{V?24v{`&Q)bwFoFiof5(P?HrhP#{6kc^%JcrY?Gtp{V`?7&or|vX z2`PAI;H{VJRyRW#@YVxQ$!2Q3~)r}_r|@#r+Re! z>sG9|kGRw~Fige4l6z(Gj-M#pHeA(EIII0IpQh#^XEBgh)PS@00*3m7F#YcC8++t} z{D-T`@JcFJpUnZbXZ@kC8KTV zGATshCechR^!KLJdB?fZtTTv)t|WY|$Ao3{7g!#~cvV2)H&t(#KZlyfKR59y%_|M! zLy(=y08|%=t`3y0Scx-L1$GzLoq7(7m7v4FqjXAbCI=k7-p3- zXtlJfR0-@qh*YJm4tz@!yOG5bDi@t2_b%-mvWWm=y`VH46jeWL3rT z3x&IsEhfKbvj+I26UJic7s!O#Z3@D z4~2dJ8^f+tDM7oWAD+{mb&0WYTFYnv7-e`U*CsfrXoSg{_O~|Rt-ruKv!@(GpvVt_ zVQ^qL-Tzat&^nQvR6vJi^Nb0b+=Hs5;egZt70g-KjjGro6k>53Zlgvxgqf{PGjuIU zP=miUrPDztX#pAg;(|y<1K68~N75{=yQU3dUGf@1q^AZ>5+4$LhCsmto62%@90Rn- z@0U196z&NPryZ>iK&8VRQE^7(c@$0{9lJwfuDYkImTBGhKng}9wtf8kCQ8y@zJMHo zLwxbI6@Gx$FZ1)+P*w$SfVu}mCfCR83_SNN=B`? zy_SyDb`K#-Ahui{5+ry+%mZpZkBdQsSui6VU8nfExMZN8c1Y9#QT$)wvhdyFz!Ks21m8r(o)};`?yuw;oQM^*` zu?oWtI|?3TpWKKue1HT|nCw6!C%)iE_-ir8LIsqwFS47G(a$)QYO+_eir^I<+rYo$ z<`C{AI(=^dM+nTlv=tQ~Uxa=#!9W&NTHIBeu7P=Q>>3&xEI5C_<^}h{OgSHFXK_p@ z#QMHYXc!VAZit_`>(BIFg0ARt$&9&cw=vw9+9ZJC53_ND`S9Z4lzfm=fPZ{1+5QKS z9i09zWI=h-py_T@_j0XCNCXmthE>qI95xKM`U`&MG?;;McAPw=U2T814^K|E(GPQ= zy6-jl6N6uD=m}SpH&d#Dap5!w8EY&|W8L2sNwwr8RI8fKOgLBelQ(0I_k{mGg$VgoUX#A1(PF)?x-aFY}xL7h53A2 zJMws%g;7|#zb6!+aH1?j2DUxVHmK_3r?- ztwR-sTAWRqpJQ$|eRkHi z9XlMh&+Gd1OdQFztxhdyZYgN!CNqFrDsXa}_+jNc+>8n$f$HA9S1M}PS`^P_Q7XN$ zDo=3N65xob$r`#(J3k~ZGN`xU;ygW)Jv!JW0J!l5DzlPj7n@fPG|tac2$LIV^ts57 zl4g{$IhcGLs$L*c*Up_zZdu>Ie+TQpQ_~2|Bg0@gvt*?8-lsqo?YlaWb5jc&-<$kb|0-k zo(Cfa>7$>YOGf5kq4?t&Y{?Im>b6 zpz~9Y(n!BF8$p_a01#)X?}S&!(QqDP+(F7fJ0&NVgp7lJ1}rq`8liyQkh<${an-S3 zBaSw8Pv12T%tCTk_DEw}@cIwl4)aHkX*`;#rqxG!2CrPi{pS8zNDZhTj)ftUY#jNa zA-{kb7OKwzL$~^q6{>;w^F%WF984TlE!03gHZ(S}fEXjU4I~1a&zyN)`788W;tQmz zG(er8&xzK*koxrS>L|0*)<@#A1j0&o6fsedp0K&WusujILk@?Rhi~)-cE~F)pQ)d~ zk5x%IDySoAjGxCmir&^ll(q;w5gr0n)ZB>+y#%(;!sURqZ8!)6#smbOeeEtO1FZG8 zsOh1AlT%W{W&82TOO$vIr*>c|;e&wH@TU?pszylZ&wG1&=dXfT!VBGjw}I6F2R6fV zlLyLAaG9Hvstr0u5jXUz)_v_xki%?e5L=p$52`=S$e^xgvf4J0o`aeMr}`Bf>81|j za&p3mJ7mZYF)Ns@ivV7Cz=(j1_OG-}c@UV|%L9b2ex*tLwK0l5DTA`$cq-B?`jB9< zGI)L4YjR)`(L5ahRIdE6k|G^qgc3R811N)yV3m$ogZG;^kBcE*PA@Ay%NQb4r6i*) zS{j&a9Sm;-RB6*|TBU^+fki+{0nd=l;7+Nq4M%|6a1o{HOmEWS4WIENEGuE01Xy-W zFc(@$Mnc#c`2Yf)^0spJjk#>|H=IB{36&0mrqQ9H+t_K02XIkn^Dlf4;-jKAK#Gtxmv=U))t;`h1C8;f!T~;ygUk12&f=-D_&qkv>Onx&CCdoC8-&+bl+NI| zvb2C4UqVbBL#~y9``)zuM&(SR{Ok+&akrG`h^!nic4tw-Q^lDnfmH>M{@@n>^fcp5 z5&LNHmB6x)?NloUVy0mMen6>Q3q4V<7hpWa)N`w=b+An}FpI}{wv^F<-Qx&O@m0>) z*chI+FHr|>iz+52w;2)WZI(izznt;G5I#Md(+I1YSWeDmW4Nb2pyJuVi7-B^EsA}> z@7MNfvoG{S!~QH({sb5mtRIy>ekxBqQZDe5i-^CCteRDi%kMYagPy!8Q3z+cMwB;qbakaE~CP%HonUA ziIW1-zdL72h&W>w`g<%p_Zs@IM@eM^{g3Lgf)vpcWbvtcsx0 z>?j0v1_l^W9x5s!vCYl_@vpRP9rmL03lLKN^X2y186aw?LKWf#W6Chg1R=rt*ah+~ zh1M3-#f*o;&M3!eQ8uY-vk{sK-2Loqf&fv}Y35uyG}`EcYHU$rzv4A1 z7(5sHN3aE!~Yh)sG_p2AD}z_ft_1KEbhF5y;TIGqoYw1w(d?(3z*Ad9_FCwv5S z1+7!2P^|f6#5>UV!8_%GtE^}-C&W=&5^!%rmWHQv zwc@jL3qHl#qo3=d&Y@#qfoj zAq}`^^{Dh}*k!yYWcxibbW*Xme;j7TdFmcWL};1Hq8x5UM65)kzzVS0o;aJXOBu_Q zXl&npML_pxbz(Rsxm*i7*^tAc3C(~UGjO7%2br_InkAO1Wv#-gpGuG&@A#mN-uQf z=*5|(pPGKxCUi<)(hLFtiL zO%HtcK1eJL57I_}z-urt(|3+*&&t zx!ue;z`>;*`8KF4JCvn&Fi|b7C#`u>O#Ea54zo1ySy%ZO^dAsHm3W zaO%q|4mls{>3jC^keNn(b%aMWo;d*DnuqDF8?+>5e?ggGHlgNv_wkTklv5=J6_IhD z_H(_JSb%L4%I1$!8^Hokz;QgZFgb_UKL#Gp1baJyy;KG;2G)o3IS!&?V<(hSZP)z* zGO?_c)PJ(g&Ax1Ig}`q>(_8o8gKG@RJgfg&WhWwum?Gau-@ZpNnP5ZY{@RIEA?(^0 z#E4)JC#CAI0C%}EwneN#_V|}qJ2!qM zmgM5A z4KkL&O}2+?^r6^dL?#@qBf4FShn^((8YNwC--P5+&J_q4|K2ssprbQByZiM#nWNdE zVVj))YP&?bI_R{(o5`dbJ~BDWh`>7dR;LUg(M8CKxDvFNEA~sORE1ZTD~#@BUW5_KV(P^s6Egurrt z&z*CX=AqwSx}&Qz{u^SDv*8p2+v24FAAMcub(MmV0~knr6;j%Z}!Y&>xG$TcYto-#F|5LSAwT`hX-;vahM?$e#nMa1zm?Wf`XuK6_2dZRvjJ2-m zhcpLUfWH;aeNz~~F=6DPDTV9%ts;Ibx>>d#fGoj?SE;I{c6(`()h#w9E)uv_GLLz=)n&3ISR;}cu9J9?sU_2R*y3n*3Ky$KS z)zddKi>g~1id0=;xQX{+?PruJJUV5XwQ%&uPnpb;gQSf#U?3e6zSVBP7%&AD&@{|# zJRt27#8yJoRBu996S|*qFa9$ZU^dx0fgzC@47%E@iDo^CknInhV;mgevv#YX%sOVb z^6P+i75I4P0f4EOj#6C=rwqg=i~<8sxM52HiYF2bA$@A2eq$dq*oZKIiFW>aMhI^( zcC&TdSdjtKd{(q1pcYGGlBdm)X23wGrC7dsi>Qx^Y7|6A>@#L;1%wki;OtcWtABf&CYYX5fQWwy-CfbRFR198NK5NS^9wa1A`fiD)WZ9F?{NZta$aj%${BtbQ5>h@M%naC7O}t&R!T`r5PW?#f^z6}_*#++CcA$~Im8)@{ z6h28z8}at&&XPmwo81>`!*&FkTnimH>b1>mJpg&t3>3A}G80-i+BuDkbmS$^yv_px zD_97!0L`-lXlnWbIWmf)n2s8#>IODK8e>)K@YFC1)q6~9iAjflBJ4+C@WV^uxs0!e z+djyD(#vum&mJ2GlW|aQy@AFajDxMM`-X$MfQEfl9>H+7X{dZIKqgAwUdStO+i6ak za%@GYtyP;YIF65X>8YFcsZH}R)M%IMxqC}~^+s_{uj7xhxJ3;cdf;%Y}_f zv$#X2#fWP*u$Q66YW-&eopRM4-Dhq^+#Cbo+k?nHF*Oc05a7vX<_z7nMcJ16)u0qx zwQ<*RpE&cezDp1McXs*ija%DZXJJ+AFZ2`XV@dSP)oq~2Xj;PXx0_<>L#VkV1aOya zM-<@3a5sp`Ou>p#K^`o`yI2ij$S(weS0^!k^JWI56=@4rIPOAGF(c4bIr`iO@zP6O z*)f9+@Hm}OoqBs9)W%RChSAp^))?g3)zXG~5ro$^F$HD2A9+#6_`!RTs2VZ!Y&To| zH1zfLV=A0)6kB#2z=P_gNsTln4`}N)Oem1OGpN*%xE%8tKG~eEc)#R|*yh-Zn)#~N z&}hrU_=oSbt~m2yMAYJefn~U;wXA0n{M*X2jm7CWkS87{fZXT&ivy6`R4laWa z2Vyfh8bTyZbMS~wfR1Mi0!}h!Jrcxr60SbciNTHBh^LPCe+-0K=Zu|sSb2n+d1)3Z zglWlEBh%c=JdHRv;|Zv}2<;!0(y44dTmEr~oowbTZX` zdy*}f{zmp*4MJGL=NoI})MHOk-sQWGTq_)h|ebX59 z>M@M-gZQwt*26JnM=aN`Mt)S70m324U~?Pqx7R0)ac`+Ei5)5cY0B3>{g?e^|kP{_0e)l>q$U>2yB3D)g&oVjc!XJ!>)Ww({|1^)$AXlb?*l5|pdFr*k@ zG$IamzKizGuQ-oSPpkz{g1`bRa=&6kWMw}b`HlM-A9E{vQ)tYbdhoNhcM||<%FT$%(?MZjHVy_>M5}_> zu6z!CA*=M!pmwn>DHT@C0P4gGs#uu-1**D2=v8ZcQ^u(74rlEWD#~qIq|oLz;pP zgne&{6|y=32HwPU6di)Y?P8vow5F=H@dIRxz%-Ody?C<0ZwDivHGma9WXl_QX!_q!FET6eRET6QKa1tJDg+}>+Xyo!CTQohEg;Qoso=Cm<~5nELBK? z5)xNDf&V#yC;fyfo5`EC!Ae`EuCaCrjWOY5XTMxAp&t84u;<8H+|A-tuCJb3 zkpYTIFJ{r`V?S5fMHAsNs)HmOoRSs1`$1=?ZwT@9-F`0@_jwxAb6w~v-uMnU3~G-A zF1|?pG=Tpyv6i7SwsQ$M19_T3uL5p;O$ibNHCV&~4}EznPzIF~Lv8zqs$N1vDS?8% znZtuCP{fFch%hIA)RCmh6E9fkOv8AVCp|aooqq{cCx8j+2LkOm6Zhw(q%r$h=nhDl zY2Q+^AZDoy4yB{Xr(BH$p$mDUzk=F!XztFg=_AMzsOx+ZgDPIGS-4uo^_oD15q1*7 zRg|OtLN$65XZ-CD=B;ALCq@QD>8hsa>@D@te@~2f>>MrP0!uO6)?#> zke&cAH%vEtYGe?g*i4=mnPsB5CZIbG|6I_~7QLrp2AvVNWy;mHcZNGxRNTB7@1K);3AIG%y(rZ`RjOW}c(}re|NY$MLn@C@irsFj;fDR&i6V7M%5Q z&@m}bpqSWsb1eso6K8QlfK^r~>p^d#bc*;= z`%NMZ>f<}$J*XoJ6_EIA8=ZT^V03D-MwV+%c{RdybEHQ2hQ&!WHIQhpwX6G(=)}qv z-DSNCVU1?I(nTX>A^D1GD3G8DI1k+?hcwd$-cU3u%SHCbvk3DjZ$l77vIGRxH;>0q zuWjl<(vq(CnPeoJ7#i$RdV+X;ZUqGeRu9OgM5U$o5N!%W6a(;Tg`*j#x}$2G(*9b5 zZmt{9M}qK0vlWyn5y&TJPY|%A%K-Z^YnEAZC)STwY;5dvTcUhns5qH{R)LE|VfLzI^`!v1;fiim5RkUc#C7u8a>V%LiTTR&NJWg?@fk4+-t~8;3qW&tT`^SoLre z;wzi90%&snXgOBxh+KR43L}C%Hu$hSF=uP%USdQxZQ4XxH7P}!W#oZoM3*itGJEDH zAaxNhz{!wg84R-z6m1;pS5%Pbyn=b2b?*>KLzaaUx2-ze_kAH0*=$TQ4$zd&u;U8E zRTF2;_(1#V3-+4`0EQhZxLaK$LvME23;KFOVGP;H6a56;r^oJu4L(7j;PO&lcof1p zF%qNDZb_N0@eB;u+asxH6EJ*LP+-smeHRy_48J}g4eAopOweBk7OJl7=~LMhTXKO; zY0p3^LJ~;SS+}e-m!9(q79CCY3l@x*e|A+<4T`sA+%Nd6@Xg+dp*|k$mRd?I$9LVx zTW6xAx+>}Tet&E21CG1yM$e^Jlh!+w?bIzN_4ThEE!-}b$0rsBwkGG@|h|Hk3) zJUJ&v$Nt(Cc{|_K?!PA`D<{W+Oo>NQQu5Wni?Q*cvGF(*Y*?9>xw*=u@gAD1rK2MN z$2CdEIxuUm90XCJ3J3@|my(iVg6#I%Ou^MEjiXE0&;3M0=p9IIzGw^#SH6FLw;14M z#o!Q3d3{<|B9UrnY7+Yn62jeqiV}lQk=?N)6Xyv`O-<$N=#niem;i1 ztgMVv7~P#hEFL5ni|c_<$ggN2lZ*4e$7Z>Lg3%2b+nW;?h$g>1co6=VF0olzS&`ZhnehoSG#Hv4L~I}Mj6_R(rgghTsU9( zK_)JR3B>ldZ{KoHLsKnLg{0oPa`kF+T82}nJ_Yr2pl&~O2M>-r4s_^IpV7IHkdWKa z(TffpIl>8ZCYmel1&W6~P+0NWgYn_~|3%ucl(yNdAdbeNw zI%v{eRqbGfoq!<79~K_Y=IZLY8P)M)PEP#4L5UqO)?vH1&8i>wK^gg{CwKr3hsJ}uoB zLlOYY(zssA`lJ3l9klp&#JhmlpX{p=)ppr_HY)c7lPW9MRkvEUwm@*`3q46A&eZ$+ zA6|QqM-zUpv|DxYW@cpI&#Xw^OR6{6iwTb&>ytMZ4%oJiM7Hl{ifq_07wotPiq21; zZj5`4O7G>cFn=6YQFzk_Gg@w$T3C1iED+^HdU?F`UR^kvk-8Eb;G^j6KpFLt0~~RW zkRR4BhpC_%*-MwEpEz-Xgx`ovlNdDdt~&ljg->m`bJg0lR6Bl5m_Vrm0EfWxQ9m#jMTTllM>EOj*T0U-kdp znIVeRJoi7HH8}k_{?kZfne>wXlr`SZx$&QV9Z{?mdmI@pX(uNq@^}Xh))L~xj^p9^ zWkS;yfHc1XHb0or9E*1NiN9ESb@iyrPFc6uD_5@k{PAhye(ygQswgtBPV4-yKjXpo zF0mL`f8*HLZVSckxk5q-U$)U*^!cu~?}eM|e(#_A&>cGQonZA)SNT7~bN>0LDr~R+ z9(O$PhWLeV4ofh;3v%nyRoKYX*$lvPJ0W5EjvYJ5V-SazeR)tw2u0kS&z~1$W@bvSTc--L z4o1ko=6~hn;yVBE;qqzIrkNlLCdAITaPW-ZOBvR+6Y z!^uh1G^*Ukah7*)LM<{6?E>dKJQ&hN_g}qoMfzYNH*gQp;1iQe$m;^7^(E*{rn2T7 zh(1K$!$zm5)e~h z&@~S9E8~Hl|F{tbr24%*uK&8-^m|m-yP0e59Z5qh`+{~&ZxMF#h(x9LSbNn9n)!*O zjEy;-rD|<$onfnMV6X@S7I6kk7-w&y2C$BrH6(m$YiqM_cmHf>)N)xzQ5+xYx`A8g z$8{fSYCv=@ESK@Q|NXh^C_fJ5TvR-RU<-nli2dHMvD$TIo7mfmiY-8#(0y?N?`#vB zwclG+Wn5~}qDgb-Djql)NvA)u@WGk}1zSK*!CR#$sxWyu35Xku_2 zUbS!EaHGp*VMVuB;Jx#0vB1%q!-vwvtP<2=CNa{hS6`{U02Q!S)>u_}`yA(f-|wS? z=*gvy*}~FNPZiNdfff}NPNJ+4prF_Da?ox#( zd0kIOP#`w6WbrW{iK*0XiwArsCguWq^KWC^z@sNmJk!$B9&A}8I zKn~h~+k_G=_wGoAOCYe}QsSeo7A5*KZ$K2;4TjC0bRq2>`kKr-p|>k9yHFXI5T`t`glsCw=dl|ENFjC6kd_%TG7FApvP z*!$kz{uR9CBxXMBe$R>AM|pDO9T?7qkj3FeQQKvG^5lt$rR4>@=~sMi&U94cSMCm% z>Jmrx>MXFik9Pyr(o7n1&hC_%#iN-GEX}!St3UJe#UQWf-b7(1&#x^&o1Dy-6Z5L7 zs;C$N4?VQVZ~}lN+W|!4e`(wywf0M%bac!IG)7~m7hv6v)X1!WoR(%Rr6Cm9$vKA< zX?CtUI@QY0w~Ou#K|004T)m{CVtfa|gRU14$`hi`+fahh{%{kav87N`m{?e_p*`!y z&6}cVFvrR#RD^g$9&}!{4vvn{E8p6g5gl#b$hNERD(_?#Y}IorqCulBKv~5*Ab}(w zW+>Iwm&{aB@#^vxY&`Y#_4Mbb2Y1JeA)BCcifOoR?1o7ZVzY>{Cq1WvYo^vOWGZBG z1QKL#Gt!!=3^7dDoU=<%x%0t;gA|(Jv1Aw6UVnf8M+b{$ZF25Wk#z5A;y?!{CoeDW z6WiK>v-`{MaAnLAQ72K;%a!%x`VLS2;TN3 zD5Bemi8G+2g2VFCx(Jlv(C8!Do9DSmvr&D=jxBfyNKi7e*Wmp3xps}8+JmrJVF(tt zVSbeGXGHzu=<)&$vXGRxc<(uL=ZdaeIRiIZ1;jEb=VsiQ;(=P92~q$;WrVDe5)%eP zv?m8X@?(k{fC4^$tx=2K;Dsq^NVI;5^Tq#3VL7+B`J~EsL^_0a}7P|04xeZ zZ3by_F35n^K`)H04xw&PYW`uL^O`S40s2cx@iF zttxUwzc&F!OhBaXr?34|$aPy1)bmSv(R|2}d~wAMh_AHJsDohsrc>NAX3QY>jsRiB z8HR?2$#lBlz>p?mDU?OYSOnKZ8U_H2oe%8NsOb!eE#rYO`4y#IIv=3{ z!1--zW#tX>XCW$JxGeIC&)Bwtw*PkS$+83DFkE1X#=Sjt59g3j$4j%TS+viaN48ch)<||;Fuml zHMTc?dC<0t=yzvzyX|QXnIg%{sapuuYIJj}ojJ1!A2R@A>XOb)a|H#-Na0l9cOC_| zF&{iYGpJDI0yb&w+n0yK_c#u=GebXF=%9XV!G&ba-oK=_c3x0W5HaQ$4iD8-nd4_4 zx7CmZUfw2#|D@Nf*@C3F&IL`8v)f(e=2%SgJ z4j~D97fq#$%dJlh_blG(KZga!KKd7KgX_9){6Yb6F(PndcjrYk8K5hh6IU=tP;j$o zP)>Dq`ks#i(CLmVMkuF4B2z@GD$Z2aj~_qc$RlM?5a553!wR&3LgaYi{P{`HeTxKbLpNrG z`7ihfl`3iWy6}Z5j zk-y!`)6=IJcfm6zW{HihZTIhAr?Cvs4}`S>o%dB;8Sp9xy5a%@19f$b>UZK6g61MO zW>~XWlnn#*p^;=7_K+90CE8nP79VxrE0)$94Wx2HH8ca2Eb0mJ;^YGb#kOY68WJL) zwp7&p6v&*g9Zn?B4P+Nb2|q*(vn6Jz8uwBH50McEnA+Eg@YVbWvqIg~U<_&rCAbgM)5_w>g9%UT+ zjhtze2ba{Ms2@6Q@SYrW4=y zFdi8hayxhJau|zZ8R}zuj?1P6!~%CKus1lSsYAH*i6_+Pn^ALJ1=F7Swe!|kP9qNq z+=K?j@!>$_?pZub4WNS}GehuV*v3-<#hn+JA31Un5I-6PJ?vl-g#j6_f@Ta-P4-#7 zlW}eori7#5RI=g)JR3SXI?xw(xvPEQ$h*IRJ@@2=1>Q0#B?XGp@6PKovbS#EKYO9Y!X7kco)K2+VOM1XV1E!l_pdh8WvREX&j*0^vHRl1kwip{Lf?u?-U>SNpejvqiI()|@_KAZc z0<(iBK4tal)qs^t(^?xEJi5OYm_zbLJzS{tNjP@kxF!8D3U1UX^5u9<3q&+RXh>P& zup2CHvYnIJ;?qT<%&A`^>u{FRbU>7e$YTH(%y@u2%tQVUu`+F36z_;T!?vW(5$bnF zO)Pro&Dd79sZ-6d*Y~>O7R4VJf5C_T3C*}{ZJaPWZ=vL=hC)*&|iG|v_b&AX2k4q?unC5zNfLp$-KVuq0 zHe+B!X~gJowlU5fFA;W_5UwD@jB5hlT(66#zo_xDMeHThAaoidmmiwnj%;V)nl-#Q zsmJ_PJvH2?_N;lf!%Z42jFt*nl`YpvGy%X`?fVEUx(#%dPz!?QxGY!o{{3u{tbr0F zMuVE7#~vxV^c_MkiU3tOfv-xFmn>Prd8xF76kyvIi;`Re;t4Z^L_)L>m;0bjfrN;i zpP!#dRRqOXBX}}oM1rXgx%1Ty7CS($jf8s+l6vi-8z`I?iFr|LI`F6|kb^yM03-){ zzq?WU6i%W0aQjX}etU)!W~QR?qI(fN@N59s={aeyb1 zDI&x#sfe3~^a;loAt?xTnuoj&yPI(B$R$LPDmtwk!Hi47GfN zn=fS`RL$CtlE?+zt1q}jnqrGSn$qKtR0of8 ze``U!AroiW4I6-_8IUstShEiDmN;9$7`+B;h0~%~uhJH`(J3{q_W(zho$&{MMQJBi z+HF+k_Ku=IT?)wrm2%*Il~kIHBfbE!hMLuB&rM7{qTV|cIKj=mIX*WJ9ea21xH$8_ zcXUYhk1t_or5!>yiD-$@t0;M^0EO~P>WU~}GyrNao+&A7aPWDcFDVb%(VgC96;O1& zL+9p!QGGhf07WfE^s&`*A?tjVtqp*fy&(bsleDGyR0a4UZb8AN>7PLWbPsi#Q(iG? z$~5L@>o}I`f)AOF%V{0rkT|nn5^F~p1vP}g*gd~MDS+xBOUBhnT~d{ahV!6u;C0tY z0RceyHd1(Mth>A5ZNw_`DW6}sQy=v<*OStR{}LmSGZ>IO)O2U-tH~%P4I}CMiX(x> z`a%f9P=g^PYeCy(jWSPgpp5exiRis-2hH~rWIBLP(O_~t7YfNAYHQ6=Z^2r&37FZ< ztI0F#`!ldj^3+5~g!$ynq>pL99Yxk&^l2QHqNpx{y(90!bBNYIzvOepZ2&Aol?Pc4 z&y&79$(y0v@&TlrjAA3oeugm$%*(1O?K|zq934@(geLE5 zO-pCz3Wgh05Q@Jduo2~AJh7pQ-e;lO$S6@4m2%4)8T4|qc%Z@Xmer?;TCOLw3ZbrG zaaox=s6*{-hIo#fVaGEl)H#3>9o2UjC5R46lgl77pUmWf>+kvGrn%wtKs<)4+uP5A zY{+8Z1+r}N?qQV3*E!MFVtfIJ)b@vFIB%}+0sY^II3f3Y!4h$CE+A*3*J&5Uxg!5S zdHd6+Uk|_Fb5c(ob1Y74m}Rq!KD9+4Al|x#i1ukPH~yr4&jG+d6ls ztMl~r_j}>mo7ZC|U z$;a2XD7g%4kE*zhIBYpn)m4gn7Y(joy=G1Ri!4Mn0sM*U-lvWi4yaZxJvL9V$5p&I z2}_vD#dHAL$M-rL?})%Q6}Sjni(m_i(*jN;a0@Q?Jf(&$MA7?wcaf}(vbj-02k3R{5` ziwmMxWwa7{BJ6P@OqKSA8UJ?Wm_D5kQ&B|cpyid81x_JP2vj&gp6=}jtsuVO1r(%r zaEa&8S64b{0UI+cPg2%&M<4`q;vwV$hWDIYF}X~zyOC{gx&+%r^ep?LPNt?hJ_LTMQ3&bSPCe=$mE z4(zRTI&??^PwK@;+9QIDJo=1B4{UHg6TiIqfF zNJ$jbIq&C6ch$d8!}&tt9sUZMKi;Q!*8L9tjDy0UBZg7Z8-ppq|IhRKeqSh!=p08Hn*>yN}y$4UL(|u8?!M zovmVucK{)yH7*cvof%p?>pRMtdn@FqAOY%u_%CH!%fJirs7l-^*r9UwKDH%gUeM~P zp<)Ap(cVKEd9cZ5lmaRa@@l>t8S0LHsts<(JScn+{k5o;orCoQ_DQls%Wfi}ku2;v z6gtHji+!0vn&k86&&kg^?sp+#098fs@*AtaBK8NsN#F}O7x-7t2AaZA;0tlL9%fn@Np$Fui%z^_g zco}TOcxd_*NG~DOtT;0-R+6fG*G zy?XzVa<r#n^!wt2xAT&>h11#USm)AbR7JOxb+rPG{uxVj`R2&k*Y9} z*(Vl7MTKO&n*4KV^8e!P&Es-h+qdt_oGe4;A#ngnUFCvEM~3|Ii!uZ4;lqcSr0UW+rKA%?*yC?+b{0pV^dE#IKs9w9 zr58jA?TAXq{RYU?5x%c|cTp%AulPTb&@T|(J`(5Gr*7h?IQ>U=OA8AYe-**pNz;d6 zzNc_mVQA3L?)$6Hw;$!To#D5T4$I3cR~FGJmG~kyA2e95y9IstnP(e;B1%jWfGdN& z`iYf8TpSIQwqI8aESOo}J6F2lR@dc7;Ea}xJJjY}H}S-0>A|5i^3OuDzFM+zdQ47E zPKgU!ZWEFzU|NU^a$kz0M+e_i9f)Fpn&3J$unr(aSluu%IYa6&0nE{54em0Z5sj>{ z-H~d0sKu)*7StX(?D&LJ!$WS-x7eUNz zZ0|tT4Z$pA7-h3wGsIsPM)|ly2j^s3cGNwx_>YA<{6_asaR1bSmA-vRw*?og2}~+m zhLrgRNm^mrxItJ)sz?Llid7P(V+zh~G_>fi7-)A*&U;=GjSG!ZdRBuE z**Z6g@AKw5z^0v@oxE-oD~Nol4qI9R*N2Z;M%ZwP?oh0GM1}~w)s|Kbp!yBk+SzX2 z+g)u>74%XYGbSl1X)tX16L#0C%St8>Px?~eMDc)@j6$lN*QBZju3qq=c86U-adAL* zL1oYdAvjt@tlsF1OXh%R*F*G~UVqc-H+9#>P`3g_Bbe(6WZU*k$O=@7YIH^^ z5VMHEFNYj~-9Sxw*tEXJpB4MamH|y{qX80&@iWs+Us zdFgsoOpJkxtE;#-9q6~T4jnm*(pQ>a8a1o5uj|UwuK{?1T20(#tU-0$R~FZmpFeCm zym*L&WztVCnh=AHj5d^+2F$N2^pau6z&<2P66+WN8(=5EJP3(P`K~_DZIX*KCy|Z# z0_etX?GK#l?APOt$&w>BpzM4Qxmu9Qdmzs>omHrL+YHN=!;Fk%^pnECvnTDYc#divun=mQTkmq}H0@!iJ!oRnlSaKL;u%+O zsD4@1<${@{Uc{acguugG^zySvF{gxnd^Mdj8S+OJVNzwLq?{-3?Yd*(eCB%4DC5Sq zAZU%#cyoJXQHrmg>n3*Kj;vHm@Y!9(J8n_T<^nO~{B_%JU~6!gn_v2r9%oFdx9-F@ z3&ly7O>u&WOEM@(juG~Uc(jWOSFk7W^PrHbl?^=aRyT0nL_(U%ZC+2a1L&?bj7aLv z^`qB(W&SGwL;qG1ht6U3q%h3IXU2wYZbCsn3|jG4({uiszSrZ2wFIV$2BgR z>b34$S2h@>`3Pc~oF8~&#<3Sk~ulJ@7XdtAGwnJ=-;8wpnj zMk?;n8W^tJMMdGO&&lRN>ImIfSQF$CFMxuXPNBu~S2 zUVxHRpmM-#u{b1-$_xV;!$O4s0cdCG*DeF6(D97)8VA~aG9v2l#weW&->tzvl4qaFmgR9|PV1+hCODvY>=dwDQ6%F|NZP?@d zGeZ}y(6&~4WH!g`+lN8#-FRRNCU2e-X?sOf6ynwWDf`o#gKy;NMmDnx+c)Zq**IAV32D%XD4n7Z9{8n6VXL_hT=vqs4#eIJH4ev zGfzQAzIpqC1+!Tstpur!r%L&x-gSIa*I*Iw0D{8h5fsPQx0w+6J-o^g0?~;hCs`rg z5+_bh2{7C%s{KRVT1T(5uz_LTP9edv%_i5b!W6$q!a6VEh`29J!6nUfJS*}tHnYF?k8U_4@) zB<4}H(nxSraW+6+P?b-ORTUMEt~+#;?|5)U-|*Q_9*YvB1t-WF&?poZ@W}t*#1b1w z@`~XYWRAZA`BfMayqUb7Xmn^QV6(&g+tqP(hsYbWZ$ISyU!Ug`KDat=x>Y`D%cEb* zIH)q2mZ+I(VZ2k$8DOhV0Ed3y*1RC^eaOm1o(PqIDL<6aq863 zK9?g@K7>_ydUmdjSWkg2b_TlffS3?Vf11<-8c?G3%HofZlcD`g<;CeYI_H5PR8erJ ztIDb!;GO?vSWLqP4J69>j#*qRjG;_RBjZ+l*Y|I$qR@#H7?aY|K;+fz)9>HkVbgJC zR4EDARTO#tv#UE#nH_ur)ZX9riW@5au)VtCcg&-S?_O`w;>8Qq#~%EZ%Ba??IS4*_ z+?OLTGQJj+pl#P7M~JJnb)%2=%7&B#Z#L&2;0#nakKX9={LjSBFIbJ5K!frKfwLPl zXui4su@%tmrKIG6Ujr$ipQNB+E2Ark%Sq{0n{BheLrJ&S0Qongp zGeuy@zO|^^OHW___3Y}Mm($=_ONL@cR@^hA{hdPtY!2KgUPvuCi1>XyJ&TnQMfJf< zJmI`9|IxLI!6;X4Z{xxN^>vP9WX&O#FZ=YmHd)2)$R(Ug!lkv4**2SMhVzsTGpK2K zt5X0F5;DLd#3)QL6$p;ePS=Q);(Fve?C!2tz|snLM*9Md(iY%jK``%?d{%tY2}*(N zMUTd|I3TUu(5&~FLr^*RsZ z#ryX^#W7*B@G~0VyXkyjUBIyZ5snM8u^ayLz{_$UO4;^i2gu`1G&Kt7((D~L&I2W z&l5@8@?&O>HrbqS5XPz;{4Of;OV*b&CtVNC;W+8lciz*zS-mfAb~#e1Q_;$?O(Tm* zL#F|qe=At^v-6h$n*&^rf0ve{qBt+^5U!f(_O<``0}QV=vc%~#W=sQ3lkqF0A=MS; zmXu#n&>?tZhP<0GHw2;(I+ zKsU>zcP0)NhsweaQ4d~7L9HR|L|X4fH=w4yodoOdxgTA88{ zV^bU42EGT~uO%%IVHQELQy3iY+zTX4;qK>`VIA*C+-94LIX64T*>7FfPLn4;Q;&R< zp5BDQtLZ?uOXoXieRyl~H!V3kl<2^5`Vw;A@8=KVlV}%LHSBEmvw}F$b16;!g|PPW z5aU|ioY`e|fBm%) zK8!k(M!A%iF$vpR+Y_M(@D{FnWYJ$-T6f>Ree`J z=}kEB0K1iQaD+7W{W*Wi;gvpJSGuN>Dw=4%MHaUe!fIkFmA%L%Z~fi0S5|Ff<*E%L3m0k0;Kjg-TS-;z zDE&ewMW0MHpmdQsjOZLsitNR@?A(5m77eEB>~?hWH6fIpp7*S|R68jDfAP~6`t7Gg zmkD<7cP`#6%{*e0!lmRxYv%3Ty=x9tu?{!e`rD7B24FP&iz1e7b^T=i+of~`gC4w` z%@Ca7Uh10w5#p|1twj2b#9d#aB{iXv9qs8kYIT-Y7aA1TK^}FcU00zpHk|R3nxlAyI)A%$Zl!%6AgpRjQVdv&01M!tVaG0_6ZITmS(7&pjJ z4QLIzaC*tZJd{1TQiTjHC%Jt|P{*78s`a_AXT;ov^vIsDBXKGo75QteshlCL@3mniT1Ohp`4unTA~^ov$bUKm-`ElAtz9 znmjXy#uhF^Y_J-c*-UkaXz%O*acyqkvLK{!*UtKg1J1#tqw z+`zP4v_q8VJ0rcicL&)T%!R#R+_+eGaQcZ~26P?ia7$Z&uMykYch8^wYP7HGaX8X; zIF<<&9Nmq~nX<&F%*Z8^E!0BVkPW_J!hLPv)?e}zz{hoU-um){{wQYy2&m@&T zapFWQp~@-$p^gwSUEKRUp-1}>u9k4n!06bEx-1Atky}J!j|^(sGZxKM3NdCo4B|fc zM?`p$6I^5Y!Ko5ISEK`8WbaJRZ2<*A$oV||1%=rr^p@La)J&kUZ3w^V{gB9^L6s}J z?x7{vQE0+xcJ;j5s4$dI?ZClw%atl@?PjC4s$wiEGsRmpCqzzRO9P-B~_8u!N^-jQKc#kJ*-yz z&0DYp)^I+4pmy`forxxXkHRq5sGYaCjW!WaK=O?8H4E@QvY$3W& zO6_4@i%$;R;#wSd{CL;#6DHJTuf}no9lUPM&**qA*<7)XC77uGiG%7Y+Xy~sgWi4m zc<_jnI?gGVmigb+yK*+tCuQcXWERp$MVn(b=SI`@1%eCA%*p9QH!V8dXD$zyZZ!4H zvJsK>$QWfGy?ubBkq*&AH~x`xgJQ)H+{*D zFAPT9B-&9n-_0*m{?YMtK5q2+l=%$5A?;Mwm>)LP46`D(Aaqz)SuJd|^zG;e3NH%? z9!=xLwknkgXlvOINe9oaH6B|r;5X~VwGAU|Y+lXmh3-$_XU+|ueQr4;naLB&_k?Nk zC8 qX9*KTCM$hFlQD00}v-S4hU=QYD0}E{HpI2C*HRhfjI>MaJvm6~OD`@6f zsn=9%v>Wv!&Kz){vPe1F+eD|&{$ZRNU_wkiD-OCB>;ezb$!1wkf&49A4d zD=JkXn@CE~8jCp70{TuR1*L`60u5v>=kl)vK3ZjsJ}Bo+@bTlGtQQf?z`8plQr3V1 zbIx^yZjd1;{*UE4ZFE~EIzLxUhHWTB{UjH6?`s~_jCi?&cW^cY@pBtD8pp#;msTJO zJhTgAHw@to{o(7#Xi&@oIDxDe>kb{*U(%cy3)~wHBjX`Tx9Qo_BJD*zHP@wJdmCWY zR#JeEsUpCADR(Z%M$Ew#(td)sb?oI|{L{)E5dR68Rla4E)>PAKl1?F5!l*bOcg{-S zll4;BBaVaDO)RfQf3G-RPt99h3P>1+BIgnA4Ua&yA*@RO+`IBtE$*FPV2=^33SH}= zw8%f={36zCAt51m*%FIk)qnOBg_D417zWCV(4C@vV@QMDs%7(gW~-5a1`NK-IwY>h zIAUE0WfX^@x!YLLp%gw&pwR8f)Fh4&W{lI!MvlCO=02=KwNFjjX)yLJ{aac)<% zZlC3?wo;i@K6sJ5gG6roMHH2hhI5ti@!Xbj!E-<|GObobZ$pCH_3lG6*vP`GMD;W6 za%N_xi)*F+g!-`QStQXqt|f~Y(O}wTBzY^FIM|CJI?`Dcf?8PjzoaosLW&#Fh%Q^U zZoS>-1-YRW7WrE0D3P?boqRrD_2@h9qaEgA3PpJ#+QU|d+DJDcOoH3SoP=u?^JISb zvADPt51PIsG{532wH_9RahmYc;@-gvYp<(2{%H}A@0x10f_8W+%+VP!IK(R46sHYN z`JK&b2LR8q`i*zdo9U#E7haf`@*3t2swYYgsh`Lim2#R!L+nIn0a8?hLm zJWr_zLRyJ_k5zM7SND>=+Ow7U`r+sAZ|7AiHX3I=w{C4X@6Z1I)!D@_mlUb&l)4SC zQo;|EyJ3K22(}{3_w#L;K^>ID!MZz7M>sRa#wrl7&2fPjp(T*oYSPTO^ub8r%qA54 z>5*0>HKLkYCFNj5j3xI)HdV-kLaAwmrzlPYJYlp+b0a?9{8adAq;9`{?HNxN{{7RJ zZS=dm=*R_-Zx(y)loSb;nnFid4O7p z@~LMj%ZP0-E|Uczhat`JD5M-h{nL)ga4Q}4Vj+Uf7oE3$-5)u3CBbFd;Q4GcivqLF zGC&2M(XoZ=z-qv_O;;+}x}8 z@AtHglHf7s#k6w!gibX`9M{rvAaE%Zv=cBEt-;6G6S*X2Qe+Z(K%VYe>JwJ&X4nbF z?*=Y-xz1FuWyn|X5h`j{fj<`f1Qn6arf-%QXCU1WVluMk!6ZT`UepTC;7m!=)K$Y6SO$Cj2B z%@s~jY6!|BF)7GQQMrlm-=AB~0k!$z53-QLmYw;WYK^+dsm)im@LJ%?#J@G`)4TTx z)@tQZlP6-L680{+frnx$r1CQdk>1y?|xyH4c%1#`0N)e+)}~E0sXWAOY3x#tPv=_;r@@vXS2NzlxoPTe z4~XQ+@|r4mgXIygRG%z)39y@cRPtSHV5FSMgf1DLf6Lnk$&aukg8r_iVED=~ zTyVv8X*c|kDpZ5fQ;6w}4y}pFP()vO<1+V2wLbJMDVt=zx5XaJSe#?r82+<qcwC`tT}kB%LZ5HwnQkS4k=}Z2}4@ zQG3$4^DQ9>*dWNn-0l!j*sE1OyuFo99FmHT4*NAWeEbRkmYheYj+~c{u+W`I3J4MI ztPz?OYgq+}TXn&(52qFZ@Jm}y`tu@v;pl{u;?A4gQZIEV+_7<^kCAV|pVj6*fd`+v zH#z>s4Vg&zatZ1XB$(~ScEp)L${=>HD8?iMA?5x?2IVk3Q-C9sP^y!YAst>9#Z55ki+o!0=}ofrw^~yX zZVS_u@Gu1AUyc{@rjco3hj~zCyou#ThF^pAh0Ji6weztU_eXhrx5)die}B8 zZ4alecwYH|ShwPGZMFCzS=dY1>#rWvqDR6NMnT*V#u08g?q|tS@Z2;I76AQU#}-IcuA*1ayu!ScnAd6`pljc zoe6Pr;@)!{KyHIY-^n>ByaVTZppr*r$y)_~0j`c-gtiD8t&%Ij*yaw&yclJpjhwCChHEozIAy1M5+B1T(Y*8ay%=YUva8Xh{*t$t2sk> z8yO9=a>}0{TD_p3ghDm(u2FaNc-~I_a?kAbr?4^1D!Ym6K>1gh)vt{_0%B4n#pnD- z4S<_17W9^FE?>j8%q};caSbLWgz8t74I*E=q>3i>v(bq+PE)JrtvzaaV6Mu>?31?h zBXjF*G#(aFbRnES@FV@MN^$&;um8c8=l_bg&;L&^a-wyI7fr!w2?;Vi4S)_1X${4> zIGl5r5>x#kL2A@D@fwKm!8y4>&NQN249OaAbZSf z=kt%P1UU0y$jl3jxaM{1!V{7{vA7m~mBW$%kD*2j-bW4Cm z;e+*jgBh4l_)hDxe0#h}1ThEx?4K|@ah7KSFLRZXgCMKdY9v7+quGEKV~q@lhJUW@ zdiKp)Z)Kzy>eG7QoVq;$5m*4CcOz8;IB0w33T!!xC!C7Xz6OvH4JT$8BQ(p?v|k^_ z_(5G?Cltj>PAQO#5D7(N@J;B8aWpnL2!=7>M|XwJ%Al8Aw+N2H?13tSb6qCb9s(Z+ z0%DV8H#hs`WMg}Sifm{`RBFF)FZYdBZ(EeiNxp|qiG2CAb_}m`SV<6pV3Qz;QI*>6LxQdL=kUqHGtJMAZZ(4 zN!rG7B_!N4-GV^!>1S;Px6gzrH?0^TQwM#V!jQBorsTpuKM1#IgKzRedem}1kX+f> z+pnbtNKD0KVl_^SB*Ie>e8SKRq&W@xe@%)f5*-fis@Tu5Hx8vnfifiU{<0#)Ylr$p zA-*9}Tf;2gHDTLFh@3QUYvJZCbbEL|_H^sdyOrrzu3TAY;;)|7`@IkOgLnXGs1wsF zDIg??gTqsO7L?PXMP%rj(zB*%&{*{Rut+4BGXJw8`w-YkkMT(Fg(fV0T+UxHia;mI z9U+cwFf9}RLaIz*Zej1d4d@YQYAWPgh{oMh^ZvDsEj%Rp<5LJh$0L^p%rqLt)%3Ta z&BT7PQN(;~$%sHrawe58vx@)ExX^k=ZKHpFrgqx6%KcyPbZegU*CLXI==au^mTOrB z*ey=Av-967NAG{WLRC6xzQAax7C5i@Gm11juyjQI1&pkL8$yaA!&bQ+X8Ms~bwTCbYw&M~_++Q~2Iv zdXN}WA^ZW5iaslSL@ad#)>Fz{PM=mN{_Xgt^-jP1hQy53Zv8?=mnJtIN{`*N!op z)@4EY-U8wsUJese%-pOvzzPp3I)Qb;$bnwt{fH4D6vq{A|Qu(b(9$ zO-X>g|2%(XJ{WW6ihx0%7XzF%T=DQ^8=w2RJRj$0et@t@F<^k(bN=DMCPUJ9sg>I@ zs(EjPGlGz+3S}LZMqW=fUcRA@DmPENm;rqvyxgbpUh+__nnrZdj5T&l%2QdRP>6F2 z?;yUZ4(sF~;7ddEzYnM}L2tHsdF{42G@vjFyn9%)(T7^iFtA#zGBl`N#^Ve;E83Ol zCpCYj(n;$D=w0@t243I=U>s}kdLytNc(XG|`O@HA^sZ~s_`NEQRuHd6F5;hc(x~t2 z-fMIys{VPU2^gYw3Y<_L5dtdT;^5=Bi4rjddV2`2v=1$F#?y5sI%d(}!Vk`8xl`Yj z?{4IusW>^^I&A)OECv(o%JmyG*yY;1;>FbeK?5szw1oDh97I%QzOHaD74FpRO-5WQ zzvhev!J6tz^B-11+9{@CV4q_3@z-CWCr2x>J8qeEQSWd$*#NDX3KBADrzfX@K<-3` z;3SIJx(Z$)$sWja|cgjmpz#R~w%lPN48@s}3@xk$~pCwgLn6L6oavGgW~YZ46BW zWvZ+YP8M+|dNCkI{5t!2NGWQ<^_0lOakrl95~v6}3euABdVw93_r1i;xay#Ksu%Au zVKesLJvC7@YX4-sgcD-O{;nmbPnZ>_gkfH!z}{qHyf;y0jKA-+Ac`jAT_KmoEUwy) z_G^~S<$|KTcVk=9;Z01NVl96F>}cmuq~~*Ty~ruE4Vq1)R`Bi9<3G`zKjoAL`>eB{ zF{1-DKsbTSjxX%evF#Mh&UbS8_n+g>)-GFxxS)^{8iKZ5ALNQ zr~9b;5mSiX*Ql)>9sLUT@#B25?O}OP)xdH)cWzj|!YlsYqJFT3dKO|Mj$9xRSm&nV zHJLz@&LOu&)25l!>YVyxhU@i2N!~F>)Mh@{2=ZAM2#^5t)FFwOD@=aN zA4qIz({*nDU-zY5uxZ#t5e}l^R87BvT?ob3k*(&x-d6cv<4tUdITTOe2y`je_Ro6` zGY=d@h!Z^r|JM=Elf9Rb(VD(|4h_JF*}bUh0;mf(zT)`wzr%IK6V~#!c;IwwOknj2 z9-Fu>3swT>My7ZvVrk+rkg%p~DIl6y2rlL<>i_9ew^!WW|EDkpMUfyZBorPa-!-XP zi71THLLKcv*s*`_1wPttCm&=`eP%ZnNg%13P#mF&)=)6jv^n8MrY(a2A3yf*G|BiT zzj=Gc+_`tuw*G6;4D8frYz;Xs|MQj6{a4tixUc`0U=1FNIxa`p745p$``3<9{?~O# zMHCbNKZqo>gN2`g!r$^n*uNjH2S4dTlLTzagU07nC_W6Et_{311FB<_THX?)3VNXZPJl?;H|eu%kwJ@CWG|0&1(rpM~UL!dE77zFWZ zjKsz)%bmJJoP9Z4)^luPkFpg|v2(wF?=bD5Bg}M195fwO8(!88I#yx>B>WQeWjL93 z<56#twit}hVa@&(V$@u#)gF835>TwPz}!mT{Yk`ah>YR}d5YD4{p_nD-d;|(N8yN` zSx~TziY{FLzc$d_KX*EPpGeM%J_MKbCOF3^P$7tSowUmt2L4nEPaP&+jvqhX5NaRQ z2o0IuoE+H?lGzjV_omLDUzdA<*QIO}5$by_;7q&y&Q#pH&6T z#c`l%+yWXXLTTYP*9CVK@eF23*EuuqQ69E?qFwS;CIuTXbAv_ZMJvawK-9ZUC z=(YO)DcvpEnLC zKsXy~IBf%vfhpYS=LZ0=P>;#vKm36XpI%v!(Usr2pg99?t0ag>LZd~c!+>cJ9gHXj z@d!{=(6@rv4I?H2N93SfGeCLT2vBo(YujF4(MK_rATf=s8?td7g0PTSX{^`3C2d$3 z*YM+-9c&{D3hVC!xVud9GAt@|s&I)TmK|knrkTOIQ>IIuN@>szWipfxQD?O%$|3<9@Fr;sCwr?jG;(we92Pih?b)9Hqb8Em21Q8^Z z_^XINo8*}AGG$s6_|qPnpP!a|(%*2YWT=OfFd_+IaX2x#ffQ?Vsqk^ne#&DFi1ii+ z=>~q6K_sxpoMt`9bc2Nl+Ni+WuYt*zC}m#Izd(uP9u=h%ZWXqMw7{O@muI_2 zZwa?a2LH&bao5_DaX)QBGmA{Sp>(z@zH&;>kn&-GH)Ym4G1!9>avew|IEhq3M@HlanQ0(9eB@DodqZQN3+Z{%`+eDeEATXgArbs_jRFmiV8iHX*FcDWtt(|D%B z!bGz$wl=X^=5xsD(uyzIa!GI+1pkdalK+E`Elg*GbuXcD#&dwyDtjt^R}gujr1ikp zSZBgFSeXiU4&8-!&M{!57 zn4;1jnKp}`@Y12y_PX+O=;y3^r^r!<293zLJ99ts!3FD^N_rG9|KRG?m85UQDjz(n zJBE)0@8K>km&T7Ht6?xpHS~_O(c_`zI(UslG);#Tz5v5uWa>O=@M@N0CVPUeSx3Ab zw-Vx$PTQfXpeCXCX|JQB#^M-8W8K{>;~WZ}`T#6OGHvSB#7(MO?jNx_i2D-+=BjjX zqUOMI62DmKPf^Xhu7GLmA?7V9Z35A#nh?7m9gC99Idf*;4DT^64KQQmAyJ~3zT5N4}#B1)T-G^z&iVHj0ku$K)Ee=MO6^F+@jbR%kT@yx2T8I`?w zyajt5C!698>ajvu+%j*D{&pPqI%vl3TX5va`~4?(9%DDAr2SED7p#STFa+Jrx~~WE z@-L6uMj(RF8&4p|XW(b!B)wB7Pl_j9h1M|vqC!y7%x(I&XRWg%BOQFrrTlOeG zu&+Ll{dq(CmDyQ)=B6Yc2YbiIL=U*4HpF&eL-Gu(Dv)+H%J~Ylx3k3)8M3E)dZ2+p zE68nd<|}`g_3-DLH?#r*Qlp^n4F-Nxxa_ld2(nDqa@w=*S^M=DYm+F-^p`BMWf_U} z73#!RW-X%VMte>Fy7!;~r^XA&S;TvV=?PJ>H8jh>B zm?4q?5NgED94>Ml+@plh{{8vx(8E5@w}=-cLmj)%^FxU*6B0;orGH0yE9P*Nf1P|j ziTn^S3xkO&ztLrry4km9*N*R=Zkw8tvK9|ak@w*|;k#J<)$#Pd=Fat$;MJkS@D5r# z#4x;C@YX*kkbQ)`BeZV%*;hwr5Clc)(xOFIH>=Oa%m`ot&HdQ&U5b_%?QkTxpg>T&}_aqsk#i;?qhYj0= z&%Y?tIQn0n71q?qsGAEhH4^v<*9km$Nbewgucl#*jB{O0qY}pveimbg3=yp6HJ8TrpkNclHBQM_~{q za0_YB$b;N;bNKRv|0Iiz&TxpB%>9b^nRS!EQWgHE7)u~@fc6zHM86d{g4au*Shda4 z&DPAr2_xRjzm=+oC`N`xqUjg6Uy*OH2HeeV-jJIM9oGEY3X&Af4lUBs(j-k2n@T2q zc{lSn`P_$-fMXhYM#t;#c)dWH$ng8>-Gcucv%>Lu{%;-CG_qaiou08r7H6{wHXseb zuyL|u2IwCjsWRI9utM5&Rm8|rr+N*F5KLHgf#1VZ-de0!4!;I3vxTX*lW@JMJ`Y6>Jvvn7)EcAR(@K(A$>{ zD@3aBzrj9B(^9sK{uh39z0YfdH}oc@4OhK&z28R(o?w2-4^>?D#V=qh!C$WL&+c$1w#%dA91*I+;j}iDc5DbNLz$J%y2aSc3!oS_QqS6ASjAZ zfrMc#Boi4$i@cknp^U@w^=SIY6_1!3hdR4d!XhhR=9 z49nI2o9J96eOH=;M;%C913HKIp@*C``AWRZXu?%c0CfwQ=A_bGoUBrg*DpZ4=h zG7{%E6OeWFlNq603?h*_RE(h$&uFvNZ@+E-_M*&|k$D-S>)7du2i%9EqV+V(-OU19 z;j`<(A~R{ZpG5fQ_csG`o}BAW6BG&PdI(Zbsx2p-T0p;BT13;QPro;%sl-HAOMXU; z);%Ap4H5Qz%5!B7pR~K|vcp+&$v}Elbj}=-B!1~J{V2-l9Rc(`T@jfiIfZ(KLgOv%ml!~!3yn+}AOxbAO zwoV%3SEznqn1;ZiA{pL znPsyL6s5RbD?Ttp1s1A!WEzhiJ)HB|iNU>Ii-ZzcvgzrSRjsKH6DY-;m>=icOh=qT zSj#w##wH|e@M|Rat?^q}ZyIOHUQG?`K65s3L&I$;B)#bieEw{dx5ojNtlkZw;Il4e z5-SJ4#YhWyzqMGRV=^5L?u=dvTq;Hfh!PFy$~2`N-i{^HVRXj60Yfy#4C>^%&b}`+ z-)PH?`}Xb?zG2TttJ;?(NHu{LQjToVq&_x z@=u$yUZIB}Aq7mzn0jkh=Z-&036&O>O|WWc%x%~nZcE)8o$#Don!VQKoZrAz1ga&* zElbZ>c9y=^y>kpXUI=}o4+p6b_1OLZGW`-tC=`vTU<77-dSP)b;`&A}uWoT0WtryK zG-QWJJ~fbYromTXX*yLLarV4e2-;9{a1VLqqhkNXut(C$!USj^Zj$q3ip-VX8GAi z@JGXj+DYq(-l+_dB8FQ8)}K3PjycYhVfGm9*HnPg%0K{FBb@cp>t=&D&Rel;SuBZX z?o&ChBVT;F@F;x+MCb*jq#Dbo)RILJ_MGS8??2{5L?}%@F~E<|Lutgci%7D`OkT

SD&zV`}Q_pCTl3fIhSL=;#3@Rt7$w9iYpP0Rx+HC8Xyeq za_gz20bfbjgrgE&Dw+Wd(PVaL@r72IDVg2}-dl?i7Gx5ab0(klKwaJ1BrV(dF6BsF zYN?M@(3>3CWP8G9;`vTBZ#e4Rswzi1n(lRHGiR~(6$Y9u0|Nr4d!0FXa>odt8Q25W zM^xYG&j@^#T+ANHtiPe*)_14d+nrOknSJ&Pki`@kp;3Iy`1zv=P^e6Y!u3kzeZu3@ z*VpfvFk$@oOl%b6;l{3}cCNu>=|R{Qhzd^#;9Qyq9z`` zbp;F`OzQ{_TQAx_0W3qQW_S`IpJ!_l{R0XFm+{}v2}V5nK&Nw4sHGe3Tk~ZkC!TyZil}`B_R;1#EX{k{=Cd8e zqg^Bq=tvZzdKOIU62l!7<}_YF^7hsZ<7fGD4NNVJ*s+TBv5q7{#6Xg@li`fJF>7OJ zz3Cm*g(XhiZqP}I^AB@`PYSfoxuKBhGM8M6#m()8MyrFX|Z55}>TP-U*C#7;4oiq-mIbWkxw^&9`bf z$Qwe3&VFcT@ak|1@ZzA#VwDfWA6Hxm$o;1ANyWU3|5TwIVlv-gpl;u`-t`|vp0k$T z`C#jnlgIK0&Zm~z*s_qvcbL~2|3Wwq5D?UE(TQEL;(~UZG}!s7*2!T++5zNn;^Z#_ zwpa)odE;xG$((C=>vj}VE}lTZALaLCqg5jEbiw#3@15ur2#SFX4g0h~9w1{JhwOiz z(B|GRzC{B~x95Nq6F$m3p=kkwNLUJ2t$5AN4=EA0!zF})5zE+K*2fLtm-sR=6-kXb z(Z@E`EMHNthKE^(YbPZpst8XJ;f3y~I?C_dGrRumSxYjIt3X)_1>*NX03O%NsHDZ$ zI(1*UxMb+(xy2pod-4U;$^FJD*ae0xB711}&hc3lPAw~;U115xy|8jMDa*c_K|5X%x;1nxw3 zvIGC?D&;q!)0LR$ilu|JYw`Q1dRxKVWbKnQiAaY}6$gI`#B4np4Ej{@mg{^23nas@~?QqKLa=@ z#dc=O6*nO~|Lw zLCOw}``D*X)q&o+IN5-i(3;z+wRF*^-G=3R)p%?~(bx zI@B=}FIh5AaJnSM+sb z&c)!FO4M;|Ow`x=6aYlPj(tKuYmH^yZli{o3M?T)^)YzRv)o5L{fUx~Rpt)|EU7}e zyD>DDl_u7DlVC={UzE0Jg-VZdPKC8LX7R@I2d3@+Sq{r}UjV`wey%vT|(F->*48>B>^iV4ALuJ)CyRH>wpjZKoRJfF+Z&%rg9ig9z7kZ<3X zNu|M~vUtT$ZMOj-jn`$`m?N0|&4j@G)-Lz>F09O@wPZm|mj0qCMk${1baH zdgQ|F1de@4kHSvEBx8whdh!T>SO#zXc6z( zQPs2><;@PUUN}1OsPE*3iAkUJr+U?%5j;y%({XgL!}41_Y7AXqyf88BK)o{uW-h2& zRIttMz>}Kw`k&A{9Qh_|L94odPxQN-J~QiG{>5(xF7Kw`$}T8~B-+o;x$Cb@q2}dW zRiRM0uSxp6Ch+3LmAiKu@mGA6!VqUX=vb%MDePtN5!N!wf@av6)ZHeE463hL-;ic~ z=jhi6_ie`+W04X`z(a*(#Ur7VY&@jPX_U@@< zgQJ|topHV*r#!i~S2g*C={G)Wjs*4(RDxS9#CnZwkfN?kRpw-;WYm-_qP1F=B)T&96_+f64a@ z4Hh#X%eCa~5q^JG+4i8z-c#CG;Qe=j)BF?lt!fn}#@B7CWRAynGqJ@zUAA;Ca54{G_7zHi0~q zGW~K%?orn2B~r+i2kXnXGF$mhaoW0R(-?ZgpK}7Y&1jy*vL@4h_Um^|WsPWH`xxHt6;3l()Ut! z-%3tS{`JzgApnRa^Ou}^Yj+`KWv#k9>#THVk2{XFh>6}I)^p7BYv81@BAiD&&7+BvC3UNY3G_i4eGtotzAERBQ-eM+U?sQK647V zg~Ui}E#+G%|1Z&DbhK+65Ra)XGbn{x55N&Yed0(a(YH+RBX%)IDY(PQ0wGd4M2&CQ zvgJX$In$>Hb!j{@BofuF(kYSw{F0@26 z*&D`iv-RGAvz6OXc>#r6CasE#0;iks{)%1G^*@S}SB1EUMz@7+0s2Gl!Im#wRzE`4ssz4|Dt%?O?;(pc$JrTl*DOfp6-_T z0K#BucE0jjixR)A$%21!s*CPEdfy|sV})P8ro$tVA+ue!Q5u>9#+=-Gkp}*h6>l|* zzkT})4>XBr8nb9{^ANpCz7|9>9b^CHPOxPnDY5hDJES)->)+NynDS(lVe|-(qQ-G# zjFfAWTJqM_tdMK5Iyi+p*xF{8T89+(Q}vSXY+v> zOm;hQi{}n2ion{K#&ejL;Jr+})!k%Ox8UYYnjC@YGCOL_4%H^zjcnSou#?_RE zRNVsz4oob_jpjU8o~5c0le+6R)}Z14_@u)YR(r#S4TGrllCE8AKY0?~Od4Tw(3apw z+OSzO`SXNYUQ4&kXAOgqoFVOgda<5dawq(q%L+#OTW-)Hx z%x4Ot)##WJbD3ZKq+PoO>3`I6TZUR+rhEAy&Tb#ual{UapZf>fAMe_=!_j3B55{5_ z!OgZr|2ApT`yXc8A6ae(Uy#gv5u^Psw({Ha{b2EsN*~(XGV?PsT>WZF{Jmxhry!R_ z-S7N8N%P#Lk2hL4-V5*`3itAyaK-I1!0S=;S#L~6%ayvX%<^fnmK@s(^9(lk9wt)< zpwpfOkSTw~%l&hjwe726&ByVM_ht5yO9|IQc)VsG8f+-&uDPCbkYpAeqo+Fni8ecXq%1-xEPLKQ9XoAAK%k`})3EsMCzVbocGS_;ucGks z+4eqo`IklZOPJJSL!Z%Wv~o>subDE_8%)jzy$?zUUb*Z37}j1+d=uN$z-~_~4)* zVRB!8){`0g}&fz|iTdGa;U-D1-Q+zDz||C_c0?88vE`UDKX_3~9V2zT{JeEiMZ8+i&Y1 zU~TI<7UQ%9iyq6*xts2t;=~}yyWBrA6t*wfaBRMnh6zolwAEGKel?QlCUKs;ARJV? zuUTvI)Cq7k09W^cTG`DNdlGoz-M}IDyS4>r3#Y7tp)?Nl=!vs{dp43LXnc8_lCbrb zJXv>_0sjw-L2$p$5Ds+PoP=kHM)gdBC8<%Gtks27W)wz zT|f~gFfY#(!E>;HIFOE%aZcCkk2(`GrD3Jj~lmW(OG%IaV7@R+Biyb=y><= z>Q-*w>~?jVJdsA+15T_(^rq114q@M%y%S9?89{i=Bb^^hPMOC17L)rVE{o|JbcJ*9 z%xx=mgE?zDjrwr$+xx?0V=T>ZJwkYQjCHRhdC0;=m*ujZ28#X)L`nT%dHzy^HC3lY zU(H1HYS-5hdR{BiUnYSzJa&#=Y|M+&S^{8hhL+ z8xh$5%y6+}C2$W)NqG%SEMNoE2m11OcI3^pv_8K}yj!{)*rWUmo<-& zv$cfg=DX&~qet$PqLT`@X2cG5`S5M6juRB~2zeKrUu~9*w(AgO;-UP(!&4XZT1#CdOckk4Syj%J0?4IQD z>+PmD+ll&jOMFGA4|5wQoxlzz1XYt@fYKJtJ$$*KmA7XgNA*h_b#~4 zN&e)L?b$iL*zfG?)=g70{=%v(&Re_T6}cq;G_1ADht;SF&M+gY+4{B|`+UC6q!d8E z+qZ1dO0ibHUbh|&H;U2+wCLJ+i&TQ3XfvA5jv6mDAn9z{){1K8A<0+6#*IU$6mJ2w za$Zy{A$fVc&8)Bu7(>oPq-c4to4j(HDj^Nk8`Q6_Q5L<*o@Z6PZmVNBJ_m1ExMayV z021Hc%Ew*B5tSbgYA62G5rozxYL&+*({ySxMhZ8LRQ7Bi(2W{1lusmat5!85Yy|R> zu@fIhUuljqqA0cpb&vTrn5dNw{q;};FJi|2ONVD82~njeiivS zU^H(u>>nQW)_~Mbm7(KZwOps-}}gJYbWF zF)q;*G>6_l^zrfOf2)D~>~_8KkFMCz4Bx1K57R}K=audmEq4eOA0edJhP(mL|~6FMd6OHp1+Lj|quyOd8~7+{=B5>LR}rJ+TN z%e>PuBtT=aC#f-0S&bwb^CD0Pc12_^W7wC}BA{-Tus|d^1*sB~F@aLxAw|T-#y%~t zaLpB$iXPA2*VAD$m7d0ayj^-Qk^;v00PfI zC;OumC?Pa5rep+*t0H;H$vC^%U^bH zZrn9o@L!n4E8Yh0VScIZ<4`FeSL?#HIbm<;b$oG|c*KX*k^d&$Yk=0>e@aynt~{lk#5tkyq!^ec~n z24X10h3++-cuHxv-}TJg6}i-Uhxb`flIPlZgjY zrQCY*#D=`qf+VG|UHKL;5RHITdmP}80jEh50Pd7e6S)Cf)-TXfi6cB6cBr9x3P8&0 z%I)P7v^mJR6b?KkiHn?%972@By82z^yM}!!2|h~7$M$e5`(n*qo`u!XjvzRWx#Wt4 zA3u8U-1*jKgz_vY^yajwR<&v(QEW%{>(?DPZ1?YlD&M(n+d(`uFJt9{-R{vp?1O9R zyK{w~KBWk*U_E8|FiN5aF_cTnB!SUwh3zWtS^ehC!?BpWcmpe7YbG`<->IxzG7kwD zwQZX<{rT+N+;Ct9ThiHcZ^`WZiFLVS-XTD?C9n6-|AlMS zWsvBT*JYA8a6l|tKm$b{9c-V^*Bh8+k69x3RT-o|!f2r4xG?)RmztgJ)A5Ok@-C|9 zN3QHUu}kB3%r#)>)$JY+rbUGkGWO{&cJ#XOohi0wz_rWx?_&ZsJLfn}I1b2{hyLb# z4wtmV&!un^IkyI@`ple17Q_&9L`eTYo#kM|Vqiu3UJ#Wqp6;)OUxVJ(|L){__5S^# zEFHShVf0U_r$b>Cs{f^ca`g|?zRTj zn>}|lY}6+l7-TG;W@7g`Fnup$r1^J2icw@0C+mq<(Ev z{7&lEr$?3b>yG62!3YqX7w8=6y|-hvgx+vYthoDy&c<@(h6ZN^Df)Q5!Go_-R!PA? z8T%KlOF_1YkAcYsrMLa7B@hZn?jY;zmH$-=cV(?q!{kuo@^s_dQI5q>fY}FMDRj(% zA}GQ;n4pc^&YcYAqxI8{pL7mM+8<9+4!OXIojxUGXYh5Op+-j(v4aYqyiqfss@c+M zQFVp7rr)!XcUD{&)6+lm+e7lF!Ag9UogKvTvw1?h zm1*D_MnN-o6-~FFVb)_xv~zriI`R=)X+#sdJU#E6cs4C6vBy@^CtVwloL8QYC%&Q3 z3ZWXG&1vntWXS^<$}+rt|A3Y#ugN&_V!LSa)OTrF4yR@w@&B}U=5aaaYahSOjN>qj znPHASj3`?wg|yCLkX?jSNMTa;NRm5d4r9r3+cHftQzA-3RF=w!?9!rC$kL*sMk`W1 z@9UR2X3jbPJg?{Z>v=t1FTE^v-@p6!`+l$Q=lXm;*L9J^#%y=yF}Ql0$m6{5?WNu^ z6R6@`P$40mWcwJIKGiQ?Ly;$B&H$Td0o_iH(uh(3gd?qUV%`zu%Q^DWrEwyh2;}<& zJabbT>gyKHL6R2n)K<)fKk}(+I)o#jKb`fpYGW z3+3632W6#EFe7Po)+tMhJ~biG;YFNI8P$~C=rVGLlX`DLf;uQ)0dBAV;dO3Uz53(7 z^_$4g=0ASC(|bEN!VkCT(jBJJHdH z7Ji&c0}^ubks2c!6#G#ID>^vTYgK2?q~f&0+ZpvSh4Bh?xeav(WQ$0~ieZ4ff5x_o zM!AoyP~6HZxSDfG+B!NV;G=heca_yb7zse$30DY{7&|TFh)1gi^C(UTz!AKR4iry^ znRHsA3Fbp&yIyr(!0f?p`aafoxo<@f?CGrQfzK_$`x1a{k2_?&gLkcu-~xm2EkaJr z?%0W9tweqIphaV!vR=tH7O-=LToG*v?s1a4T=8#K1LGc%+RYZ!1BrMSdNfiFOY8(f zB!t&I^qwH8GBUamO{U+)FMbjPe*+HxG!&m_62;?q219hO&D6CO{e)1YK;kOJP7A82vJtoM-7>e76=lRZ1% zQQ23@e^_bh|5@9MgwxJadaA6rxJJgZ_PKfh-~xvcg^B{yT);_~O^(#kI5+RjkDe z@`j)LqD>odOCH)X3vrfBa2UQYuW=J7KcO01;8LxE!kfYZ8Q5>oXd8UURW}DFe?usE zw@#FARMK?@6GaltmSp-9X}-Zx`)}Vth2s^0l)m3+=p;(=6`INRUa=-}fLlYlD$!@}^(5Ws%r3=x& zX$>ZEjhVMxXt!DVagXjn{w!obH4rzGb>jX8UuJN-nX6=~n+;0hGP4Yrma1zy2Xsll zpoq@6xf5S_VIw;ALSv3YnwkWvuj0UbCZ`g)d|H<@EX{&7HaL9)!d;LS>c`36S57w0 zqp|H8L@eP<&!+t})`7NKTE|&&gx0BOpVr$V+cn;CzAiB$L!jDD;rk%=rqkf>C>}>G z_1mTLsrjpDhm`~;tXbA3T0)&6dEI3vS6%#gzz2Ddew9KN@kMx^l32e4>!04;SA;5 zS$jv-4fSUcuC_&Xcz@;ETsdJT6XxHD@{c|J?Wdwu++}GAyCHui)@N`U!N;HIyd5@c z43BtzhxzR{5UuhaKmLgTnpG*Aj%m21pvCz|CoG*Q6)&^51q(VGHIp6-oCZd*1j?Ophdr$ zKHKOF*CRjJ+f@-Z%GaD#2yoql<$2Aoye;@$NN8D0DVqC}Qe%r_&W4*61wpyQ7C)Q4 zzFtf*lrhuTBn=dt41N#POKK!12Bh|$$XU1B$vb=3c2LpH8y=;740@)(DlR$rw$kgP zw$rVfbq;Kz8hL$WjZM%ZB~W1&n7 zM-y*@sL3Zrdl&0R>9mY=I5Ve>XXHz@WkH$xyJTHO15&P2_j^M33e3=FnJhCka;B8e zT5E#5OHE$nCd^{#i{V4=B$wzPxY>59w{E>p!`FWTCD#lsjv1WH9;O%v8edfx-tst~ zTI?$9?(l^hoJMI-Y=T+BV>C=(c3RnID$lHseRMeJV$bf~E@}IC7g7b$ldu;35>5A- z>9zHGJWvf45$(AsIxe&H`D{JBE`0y;x5DVP6JFXbZSi&c;TCz0^{hyUucW)1n^q#SJr2O#O8 z&f6tpfj=3g!J@oy;TAu?J=#L4(T;1Ub{xc!a6*;U=!is6=4=9?M%LL z@8QGq|1F^K1Y~no`QU|>6N^x(X_s{>_%eZFfyw0^(o!C??UOoEO!@HM~Ln)VbwPCFKV#V8((7CB-k60 zLtAldU@;o>g4E^?Mu{V!D?4besWU^~Ur)Do`{P53S)0N`H*MNfd@}!<1TZ>K7)hFD zzH>h`UYNr?tRi6aqZ5BxFCr7Jn_+P0%T}z|1DqFMiEtp8V}D@uV^&z>)ydbZbBXz| z=VZYny;iTDcdzWliwHaueY$u5_1`3g*5(Z3eZru>akt-NE~`a=&U+E|gJm6&?B4jL z0isD#cJ9W!Kf)zNdH{7D6@?nsN4RUy(=&GHP((IyYXCftu{%1Nt_=V4RqD-aJ@M2x zldLtcvS=u0*OHLuEKcCCx)n-N234H(yz>3W5)X(B-?@Pg&AjFNe{TlB!A|W~t#DF} z|K3@a&OcX@>VAUmkoV8zmw=)%M@OeZx=9`(DH6no4UV(3?NK@)eha@I`u^pYZQ^YK z-eIgziRqO7=arQ7ZMUy7vEv59%RYSubbDNH$elbx-+cx{(t>|uYf-bpm^rBet+CvXhJmq&mKbMRejCGnaIJh zhnA(ngR%#Yk)uF(jc$iU=M-2bs*#IVzJJFSUXRg~Gt}|a%n+=*m_4{b7|heqt!K zf$pBmmI-YH?Q9Omko5cy*W-TyS(BCZ=O>^uQb+k9%9n$W{`>^~a936L{MNsZeX4Vl z{_x;t4}PcjC@=DSM5BVG{KfadyZxi}S4t=D?wVi`25o+ghS)R{$-v#Z^Tkt|-(!a| z^kL+{5Bn5nY@V~;HpMHVQbmclPUD|GHX+|rkkyCMW(h606;cF?c|5!0exRi2@PVmd z^t98>KKdwc=0$nuiN5br!TXuWnRfvO=tsTv^ls1SM|aUvLfg}JPa%S9YLYWv(@s4S zg`nbdA((wX*wE8f2?8~zIoF@jM_JULTe~*l1?Y&!R(M+mdkxY1O5RKN(W;9IH@;-n zm>l~PJRQZnqiszFc%%9?!@Yk(7VJxlAdz^oP9746NZ&IOSFFAo-gvi`%rZr&(eACQ ztLrFL*$^I1e%H3KaxWCbVQD0s2V=p`?>9+#$Aj-~YpEM!TK85KetxvWOK7yD`{P(? zlJe2dB8=8$XO}!)&uH#o3UN!QC!U6^U6nd!z+A5&ynHVA;wgoZB#ny3SeCbMtD9Ijl z$Uy5a?xY zY{Jtx1|06Br)|&tXUdwI`02YPdp{HWsx8oSY-YPQZJgUUZ>Q}AuIq7{yt3rA8|6y1 z%-${in$uUhsvXv?TQ`GBccg09g%VRJ7d!%U3?VTyqneDQnAklo=A#l=l^;v*BO0u( ztI8Cc1ES7Rc=k|7C6CCP_Wbu7f!$P>Oq7M);*(FN5`FTEio#G5#HT0XApNMqUy9@7 zN0a47Kq@B2P0v9fixe)I8zy=%eDFb8YnYzaqjV&u@3p zV?16`F_zkC@%f*aBk)CUo8j>5cQO4%GC{y)Xc8IIUHZsN-bTHM_*;pZi=H%YfqFe$ zQ9{B3vQ9&<-TKjFGW@f|D|!O_5V0JFeF_#zAxWxA&y_9sb$wH8isHZ!76|xxC-0bC zmhao5bwDhT9ImtRne~~AV@BI>Bm}K}<4Wdz1m|2aqd1Xlj^O%_m>D-i~JyfCt>&lK_Sa%JPr8_IiJm_;5x5i`zJLp?g5GQgrJJJ+kQ zOPiCj+kDrTt;2f9fHHDgyCm^<;V3H3hzO9u-CHhyX^q8kdwPOe?bq4FM)*79dXn4C zSY2W4FzH%efq51qr!-gKWVZ?F7nyJX6_ldkMK4QXjFJDh`b zaItLqm%l8)Jtws8X&-(HcX@`ozM`x62F2eF&~b&W&E`2gbo~9LRI1tQo5ep`MUxAT zU^iwp31q|0qR}e=(1?E!#=Z}?CBOm%0zbv#vxd)#ioTMRG+IM}W?<zBN9+UJu$?-KZds@ZkL@wmt1!taj9{Gt(Z z85iP~01{q{!(3Y)fu=1Jk#uwpkF9YS>()@BwVit@^~4C+dm>iJPbj|*cKi8}|2m+_ zC=EMa+nQ8Mu(0%$?Aa^Z;xb&%+P5!9%EqLIKckxFqSmPZ>a;-1UcHif$ywF-OZSM@ zz4cFpU?)h-zh?oNWo<)qkcbz=Bs{YW`88~rm|f2vl|Fc&=8e?ag}+EPulW&NYJ^)p z;E*j;Ll}!d(lGJGFRMHI`|GQG&sSbpbW{32#?_XdBWOX@CxAIu;7#^vsH+imhuOc! z(ed2YfHFiN=cULTzT%`1yPGh4LIuDf#h9qsPm%-T%94>bvhV4hc$6H?TEnC3$(D+K z_)puEZ|$WI?Zx7Bm#zt?ov(Z%dK|pKK5lo=nsV#7S5LNxr+eJ_UAuP0zb@z`wmMmT zEN}6G)RoL?NPYG4DOEXP9?o=dX_#bX<@?aP%v^#QmpPShHpWzNm&_$nPtmTGPp*GU z-nV)4Glk}DsW~>Le>@p+ubeBW!^4nAVl&^T8NMH^Gd6#}*#3ujNqPHUIyD+*8lS?- z;6K6R|3z^2|8LfKhyVYXss1y1|6g4oUNReKb@??NqJe>gHjdIa9I2ff#c#8T%bXXj z>B)YD&R`7!z&*H04GUk7c#Mp4q?~7Jtmvj_w7G) zI*+3kC7f)ZL50>>gM|Asf{h{77(*yYp>g+afB$FCjN}I%#qYIW9z{~13%`k!pT;{F z)AUBN6{)}@(ittwyJ{AzFhW<=xsVRb8;SKc-1V zUef)3riHKE?mze~*Q7;@Qu3X=mNhnJqwXc94EM}Qlo`$(`!hmo)$%A+%SQ4HsGP-T zu)9dh&of6yp?FF?+jbt;mMMGIRDRAq8&5@3W3i{N344J)*kS-Ij>%m+nc33g4mh5Nj@QUwv8L1bjzX*b-UjmWJ6fOE7`zFjAO;?)ZYb80 zg3I#C%5(&!J^ubj#0o@uyKTI}<(InE3?1CbVoj`2qm2%yi`mFj$B_%%xEe9Uj~ey8 zlop655C=dy=X1sH`OBl`BX@}*4iN@JQXfr{c`_Nm+-=wN(9 zBSfb{Iu9Kmcofw^IrAw$mWsw*k+K6+6x3T$MT!0YBFxHZry&@SDxKQ69f6MN#D+7T5M^|yZ0cLZ1DWyBd!ip%|~@fZvC zK&lMx?Dw%M*0bq~nU~C3tCl1<2W?!xZGDd&AG)>)QjnG{l^J4PAXyya>A-z3G}C)g z1&Fb8rY}}7!GW@s@swfwY~0Ha4YWQ~yywq{Wk%J;E%kZ;OoX=C4G;!h%agYLybB(*g6(jsUgNkBu z(;g|7h3YOwJH7)+OFTbubPJ*@KF;do}nVh6@ans=>u^>n*Joe!M1Rm{yVEKhxCHPb0*nm+f|cm20s zGYk9+AIHyD{bEm9O80AY%pQ~JNdvxM5|j; z)iP0kD|uQ+ySl@*tiBGKtZm2#u7Xx##cCihIj)G9Za*iT{K3bxFTvA|rD`7|!B?1+fzl;N&KC%>&O?)=A5g z&G$ttZI!4mwTujx>Od-byV_uz%!3-HIe`K();aaKnziY*uf57;o!88I*@P|h3cT#t zhqX<{9Q~lzQY1*0wssMfOe+suAr@hZB>>)>`W7zJh%t+#_{H?+NdB!;jGkUi^^)a~ z8xL2P^hdq1+xOh-so!7^x&a5oQWP%+0jX$3)fWHjMy=JOq_ZTs&9;`wortk4>Z2Ht zq}~&ezAAn4gt@P#uf?4&t}pDHSFw7KQ&NBCvN!jm@w-yq3aSvr0@Wx4iM}UAbeAq&a?k%j6>EywSEPkppPE1{9?oesD6P||9GZ$Z zWuSRH*CjNG=*cpsStCvuA~H_%)o+s3_0FRoGT&ArdZTw46Y83Ye0n?(eX!pKyyA*R zYw+R(L{!kMZ<17(QQjy;v`%C3p5pmF{^i%3k6_1L{o-_{2P+Yp8&BrNUV~yv{nWWR z!K8*F1Znp+GBQe=yEDkEpo6syy;!1;cU5)v^U|s?Q~B2U7tY=4rB&LlUswzOJn@#b zE(cf@X2puNlcT5V{_^DBRH@LyH9}Ylxacsbu%ixV%FPepYyWjKo@5p=?CVeB;=Ac@ zAhqc-?aS>xd7{*%3(+AexYYMbVIsoQU3|!GsTQJt!1SkJqvJ6TS2J71Rx)rQ_9$O+ z6s?bKKz; zcZ$~)e{RfpE9#HuwV8$)#rDX=UymPJt-A1;f1v>fKwis=^OXdtwO{@E+0~|4jcoAi zg;>_y9pYK~Bjdy?v9p!n-Lc_58=92{dOoy{i5UrVF1DU~fj8hi`dMQWnF`1jlEUo( z+<@3)0V;(-x6jasS#Ug0d7rje@~~K>O05FpiEvL$(oCVr8=7vv2X7u>%dzE15olgV zP|B!=o)mu$gCw?at}Ix7lcf=^95`4or|^bSrOr6$)a=1z08&!~pA_}lvUg3f)F^W@ zrc-a{W=|_TVWQxFZl1NPz=b2H(V@`LVh$D?ROIlN!n<mWqo%_7!=x;bAgv4d{VLn(?^dcrV!?OP6_OaVdxOY`S_0!}L z3_KQR4*Z=z$`Fv^$IfeSSy#DhDcwG3fO=aR6?MOcmvuz+@c|1LZ$~RS zTw+h@YgNanr(uVK45^QCr;Aul_Eh*G`m|k$rS+!c_AXQwcE84W?J=4AG@3cP@o65J z78!K9e^i{gVfHZ>mQ03mUV~mlnNWw(J~-_-BmD+HysDoGn|olxu3eMkXD#lcDnFJ` z7O*F}CJ~yJgJ&!4JvwQ^Ycg+ESHC(2!^N=)mERGraDUQ+7eN+0;DCFlykCQD{a`cO z6IpPDJ#&)NO0yM@XQ*idhLBe1UjY@g9=}Ny{IF=XC5Q4r?GCKOF=yksM;$d=nLrUz z0&^b3vEID-3D;yl?_e4N;H*8X-uV3a3;z6)E#k8JKr_|I*+Ey#%F6r;n^@h}Pd;Z_ z;1xRyNW9}?f^#J;8m8yc19c`nh9q^9<%AjrR2@;Y&6J~d^gg}7V`oA_Lq{u>?*!5? zx`}DTL0tm#KZPwN>2g(Bu4PUk4fDLtMx#1iy?nXngP<mSovnWp4%}E;T8qgN@Gk0k7ag#yk z0y|hX;5?fRgHLRwbSug>R>n=F8l21I0`kS!jP0%k$T4sPphBqO~T~i<64t1&ZkCSH8 zkjAvLO4se_^&)5udA5g7hX*$*ww!XaJ~v2pl})W2qn#_xH#ogpPs5q^vAZ%Me{uQW zv&St$Z7O<$wK^C$fFfTPWbvF;b>R3y4lC3tsdEiqC`=EIv~&#g0r#Xrk6gsFi*yYX zJh=AN<7U{@?Fh*y$L~$+Rq1-;{t%lPdYmoGmh@8n_K3t{GWnAn*3!cY=4I91D{DaI zZ=Odbi(ynAwWhKOe z`&y%5+>SEnOB)41Rteij475c7W?m6pgjgC+ zOM&!@2j?=g_i8jhQv%oX!L1rzDLZzl^Vt)iF$QZI3-}p7UcGj>=o7G1DPso2g>a4? zGA#gU6jir3xIq%spJY)qmCbVKkU@2!mZ1BDWlemW{e#biJ2;fBFoOVi-LM|fSwu&4 zmoK_>Sr~U<)v8s`=-ZM;+!4nasZZ48ecB|DQ&JB!Bs_{7lWv`gNGisvBc!u9z;2a0 zxK5bLy$d!5X&fNX z5Y#i4j7veagkQYPB?N`|!!Oy$Qm!KYKq+>`m|MkzEjrNf_Au2S-}rfBUlvcZI4=3% zQ^CO$QymkK4TUg~8ZF~$LY@#uQt6Ld_!`wPkPd0;hT2(D4H0?1d~?+XQsou5{0)Y| zeoW&+s$h3T{uFT;!a^DoOD`OH|D>2Z?lY$f<}_)5ls2_!-aN&5yLl*WTZyA$d}m`! zBixp`_E?~@$^c143N-0kO%fF&dbGA%1j!VY!*y9 z)rW!vh-JF{y%4*v6vR=VuoL@7)%HxgFuSNw0^z{>1H*YO0oh^-5kmk)HL|=|RFFis zS|;>&3#B6s8I;r`gI@Q<#Kg$))a1jzOm~)oZ%|`|j@STV+w4Y<&RI(HNu_B|cQ@w#6l(RPVWy7tZ1gm{{}1eE%- zsykp(8Ug32*MlvC)xw&Jmctm5 zdb&qpJzk)#C;=ToBgQZ}E=&w)&!aS+sjgcggTFIkPpS9O0kvhMPXA^AGeRk#3CI=F zX+Ts3??b29i%jO+@mbrpF7z7iNc*uQppZLO%qz5hUlF}FLFHSSYT=!D@#4uZ%8H8@ zKsVUus$b{t0}4{&aO1a`bJ}m%i36Ipidnjptese1INt4?wz$)?I@ZiA%V&6nkr;3Q zK38;SoI3LO$&)$s9=**}s>3I!=XDS4aVubPyz3}K(fXYlK;T=|Z=ZG@0d?xfhUg?K zcxm066~i>2eg65JHUaiEBg6!t-e`p70h_LIODYnF*_`#q(-cf+G3k|ga%sM5$6z@Y z# z5U;*Ja;&}m?oY0?6C3|09yNJY)&F75HK*aj*~9!Vv|g6`qkK~9ug4ER{`cAI{|iB~ BGVTBX literal 0 HcmV?d00001 diff --git a/docs/docs/assets/images/llama3/lingua-7b-mxfp8-multinode.png b/docs/docs/assets/images/llama3/lingua-7b-mxfp8-multinode.png new file mode 100644 index 0000000000000000000000000000000000000000..a2afdbb77bc94916208ef21f57ff5a0b44e41c22 GIT binary patch literal 227957 zcmeFZ`9GCy+c&%^Du-t|z~K{J}6f(%=#H_$epL_T4YJI9Hz7f8ylH^sQzi@6LF<`~8NS+O+D_ z&@H+XCy$rC^m%uKUyV*p)n;O;^Xx!*t#$ef^tZVQV(*wq-OmH68kq%Wi0B zs8h``mtZ+N_%$uU%sntLkdcY0shUQZBCYx zlT&`s+SSQVx7@pT@4@ru!S{DvEUm6)xqkin%+jLAbxW>_`s9NOF3Z#H z-obAcs2?8+h*rCVDykejs9;&}_AR$&K3S4&O0pg5O=Bqh@}7ew&n(HkuZC9=LjSPolE-?dSR#XMYyB41CWPefa26aHr8~73TNt?aJ-F zM!AA^Dt-T6+VUgk$rq=&#Pna@yYF4%OP86IRaz4(F zFyqL#_e#3oMGf4VyQL5QbNKJ-JV;A3Y}DK2ch48^TKw}TKd#8qI8r4}*<0!A^q5|6brSC}r0od+z-Ck0%?0xh3l#Yw=Ru#vXo{yuaIS zgOih!@N$xr`HGWzdckM@cvoQ~yDiu4rD&7O&3bgQk!uqc_m*wJmEj=l3#W$k7cZ)x z?q1fCv8pY5>pC+2$+^3NP24ccFXGhOD|WJ%TJwaTA7b1lB6Ggx?(28kwr!Jl=%8pP z>?yCB`1?zjZ_j1MeI8Ty(rO+&c;FeW|Lo+p_THeNpr?Az*L{qZta~1uzKT!YXT8jM zhX&QnYtEcGLypxwNz;w(h1cFcI(%m}7ng`j@r`?l$p>#f#EpoT6%u;1^BLPVt)-G=f-ah z4J!Nh`y3UNVPj)^b^gGeE{3c)NtIn*(b-6GQAeB)+8yLKAvOB z_gdW3Z>Q+u;=;-!wg1ofx!!?+4|^+R?OI-|D(glqWl{II{s}n}xxKTuI-24|-F0=K zUeN98)u<2sYtSlE9RAZCX4S~iarymq+u((TCklJ_a>w)+U#F$Meq~V^rRu^cXj%OI zdo0J4-DpRNb<68*FHPQkta2T0*%iyc=ka$yP~Z4`ZG7M}z2^^4JUUeW^{eW+b17$N zZ`;_|@F==TiE8QUa@<_SuQJ&9g09IOhl7)gtM{4QwvD*UJI1cljyt!!b(M_TVN}6u zCSkq**nP|DsLr0M2)C)BSm7O;H?O$Z^s@Zc8lk)4;l9^t8O9rp3NCLG-?77Tj#W!b zEB{9U_4n7b_hQv`b?j^#WIJl|W5U9&V=tJE_f%1Q*aegv3WX@v7cX{~`m@B_+1S~I z3LXq_7%wj`=Sr&;JC)wCP=wl9-`FUj6PuO!`t`k8x$P2Jf1`Uv)&|yNO~j3Ij)vzr z^&a_8+w@w?sS}B^G^ zR=-EdYsxT{9L>%+lkdB6-__!;Upb#Wd-n8X+S;qQeX&K=O-<3$`R)fDt8|uq3lCoE zuS=@zDi8kr=rH>Nj=Pkc+|Ap!SKW(@OwPz)zj*QD-{0R@uon-XIu#`zm-iuYpGQDP zhne>MSFfb0@7}TKX*)PbwdB~baBy%uO-oC@@a-9of|IbxyQ?Dl@0{$vP~}?iE2jn& zE|=ZjxR05gojoQdX1|^uSEQJJKt_fnE>YNp-zLeW_am2^hsXPx8dhBH>OfU>4UNth z0vk8prlF;OaAmlqzNzVER@P2dR#r0$iyP<<-T1mX%6kO`uTRew78d%tv&b~{4mM>- zU;S%+^7$#Y3l}a(?cAw?-JNMsxr(Bels$^#JW@o=wUS7B{G-O+nb?$}tfgk#3M3t45+X}qA zyk>LQ{bz@d#HkDWqHIaNlT~>(HQ2;{`}S?MRp0q|cy5v1&~WO-X%UyqH8axFF@b>^ zVM_Yys;YiHXHxO`GBPqVv$MVPjg9YbvX6}vEuY0n&qO;x=P)xfqgYy64h#$wH#Dr> zxN+m?^z>oWxx%F7sbM1=NdawbZK++m{0j>9pz57?byn?bhN(cp9+%Q|!)yoaM$dyg zmE6ZyqJ8_nd$;fC(W4I@K3uVH-#&8uaKsNDJsQ&bL;Yx!v%Od4NZUJ)&pCERJiNTs ztpUNQPo7wGSA-6<=F5nSi$6$BrQ5Y@*U6V>)+Z(=HsfW1Nge`9o*(8fxii!{|G;^Y zT*h5`U*$Z}rxmhZhCY;6E+9Es@Z`ypXb?;+EG#IfJnuZFD;iQY9fw=qtX;gM!pg=b zvb?mkxVJx4m2&IOoqp;8{g_#Gz&xc3weA3jX)GZEc{`|9StdGls@@`s!IW9@{Go}Qj=zq*vMP*+Af4y3slQQf3< zj~`!sMqhaCE5q#bOi_S0rUP@D^Q*GYf9mReZu~aq>&U(7x{8Vb{iJZIORe&4R&|N{ zjHK-M@25zPRD|*x;*wLn*zN4>uFU=YrP8En9oIMAKQItuzpv?qk-phJ$uoJI47ii) zPUqOhb|^)~$BP?E;tb#hoOwLwI#7T8o%eF$^!Kr`F)CF_Ny&FoR!*+JM@CMLyT%o0 z;4${<8wnsW%jKncNy&YiHtm00umhbBO-^Hn_ow^2d?(jl(0zF(zceH?G&Sm5x}i)| zczCIprnSHaBX7f;jN=L8^VL6TW=xg_M=9rqA|8VoG zgPByE>T0U#`t|EyVfC4#oX2}K9=~|8!`bl8#(iCpV}N11ak$g^Kec+((zBYOn{uQF zi`(55?>9;AcK!BD4<}pe!n^T{@7G%l4Saqyn6?^AeAa1 zArY^*>%!L#@utt8&2iScdkk|g@n^(;d#+!t=Uu_`qww0^jC~PThJt%@mDfF9-chY^ zW}x+JuD$70f0B2f$5GD5?BS-@#^)ohtdiNatGCA}&%yk95iMb5M>f9eB!G3HFG|w% z_6wujBgTobv4LKTbDdwF9HUzP$hEIb(TKPk5h0W(VZrL$S5tz!9Q>kCL~!@+b;Cb$ zB|qox(9zX>W%Ff=HW1(jfc`DcfG#z~sN3UE}1RNjOk2EQ{2Y7=y032xIZk~bdy~mAR;27L&-RPqoSM7)bNjPpjVv#XMq{jhCO}Q~Y)q<&ZUwzpSP8 zNu}ufnNvGuWZWmD8=vX1;Je*K2O1h0s?T?}LK%PG(y{?9rQ-Lu=M($R_r3yf5jQP7 z^Lsobg!S+65AOcj1hIyGVG)s%FJB^DTRj)&xX`v$FI;%mYH4L@ne5zO3)~%f^oP}# z$6Ba-&adpNeUBz8t{2sgzgwVKwA-`krSa(K=v^DH9JIqnINfeI^I^Z^lBmudBBG+s zZQ~sCqn)&+rKR#VfL;DCUrJCq(4%6|PmcXu`ZYM{+!o}4JEfwcLRo0`cCO2ak;_P5 zTAH^pF=uAKc&m!5y@2OkW=*Zs_adLm@Sm z(4uq;3k%(pT!o24f1jEbDWMOf_OY|F*0_3$hb~Q{roXcO{Lu4GVCmljK}Ky34Qtk{ zvD9HFRn6fu-)(H35B7hsKaL+iz8|-_b70^WijjF|sein}B^@1|Osl$$tQ%z2!>4+y zxkGRq^b;c^x#-z9hB)

;#BxMk&57`1^>jukQ-n_nUp6;@$&p^nQ%I0SHm(wHSo= z>jb)skBfKx(^`TnayO3&ss{XWqdnHv*5=P_6BQG~Xl!iEDxk!NBeLJqQ*mK&aSV)2 zBQe9Y=yLQ#%CWE)xBejD#K8d85&`dp z>v3}SXK?}Xaec=U6-xk>JAePaZ*SNYHy;%h1=4dPHg^454vsY&H?E@d7pSY4#NNOM z^`+~hPq6au;q2?{!@|-uzBI00Z$`Z~zq{f#e=v4p&-|>KwC9Wic1~o(V-U`P5Q!2qhN=QHLDIy3|S*Buxv;Py-!7-oQZ_ z1Ge5feD>^SY!Yd(Qd#SI3RwRP*suaEqk@70kF51-F#5_L=kXV|4c)VoX)9=HX$J-e z!3$_nwcb}nh$_e}d}(ONbQ?Qs+njkId>t28Ig8j}PtVF@$Bx~@l@a=gqq&BcH<*Q6 zf^pSR3F87z=mzn!y1KdmI|Lr8A2|ZPLJ?K8MJE6}(tGB-=rup3^>X(DBm_cppL{y5 zqeFW4XR5pNE!EBO3eNO6bJ0Dse}8cR#a>70;mk|=2X-GF3hjEmy!Xmb$EW@r*AWvF z#@5ywk)Qxw;wb0?sBb7t^k7CQDJdNQxp_{#K}VAZG(Q65%O{^)TwMH&0y5N^AK`q7 zY>$9|L&?=m_4WJB%r>2VZ7t}si3-3zIyScL^7pRQe`YnbwF43pO_HnWy;vQB8O)q7 zHfIX+>6d=51l3OiWC4_pD(S zwY+qxmL)&Wp;M)#PW;#1XW&Z>a$7BifS0NDZa9jwFNHDyb5Ern;WEj4zMxXY@SDkpf8W)65%z6Vcd8ynB&EOStwX1;nfy+}(pVZZ zjx9$!dV+(3x_WKCXNhDisn{LC+PWpvgND{tTUl9ER_UG*)MGJ9RJ7!lOTPPzZc*#L zY#`8_M&F2FwJGrI(aA{Bl;3YIHuXH?pg=guRJnBG*W}7#mXSq(=2I`vD3}y-h-y2G zb6r+43wedx>iCfmBdgikguU*=N3UGTF*A&fjYUrn;h37mwH`fo?2hgs2ulV!I%}SjX=rE& zr^jhFY}f#1hyT^r z*LS?T9zOqKZAYb%0~DC3IoF;KYhr&p;Zqf_&E1PY89p^2_gYa=(d}3lur~*aWN+_2 z=f0!KA0@P>rlzc$Uh<2+a2af*db&?~g44e1Buzr_F&(J3md0GMx7yEEX&M@?zu1_5 z06PUVqugU~<;)s-y{&%9is+4%djmp4Llx}Xf_wKZ$;!%(&P`riSe$NmaCW}o?@ve6 zRa5{4(H`S|$S3NWnh^`_t*!Cd3+VhY(iZf<lW98s5hFj}y5@P#W_4WI#-w7GC%))Q?Jq;W78DY?;rfU3YPNRcM4IAw zfZigwmuQV68m`2dyyB`aMc&@E%M(jWCr+Gb&NQdJ)L!^$@1e_=FL!iyp5W7qt|s+z zd3tiP9HNuRvi*;o9rglZA?fMjX@WKA2@by+MDnX8Bp6ol?FXazoNaXs_1?THTm-uj z;H!GxPT6bW78v-WYx7go0Va14K^81uIh2y#m(P#SbmA&CK1qHpLUrdG$=k5=oO<#@ zYtB=S(+7g$>fam&B3b_xdyr8u|3o_RDSG6r?Csgr)YP<3x6Rs8Mpb_AyZZN5wCITE zHp+vSFKG_W^YL+=*Z>FC$T&!PI1;6TgUi0o=l zO|QHuzc@42?Z3o|@?c`E0DT)|yEZ4%axZP^d~{;sMyN=RAf_6pA;fldcb~|;^i!p2 zw#cj^WPSKerowCU0f+BwGza>OkNXJhav;Ox9nlGOW4lJ38Vn2!J{Y(ALJwekJU=(* zM8CVQCiVt;YzO#!_-lMczkqx*br{I$(tjKJ%MunWxX?eDxaehXFmSy|z+1B(0h z5jlg%))a@y-+Hy=A&EvJWLukDs30X*%dd)+%F}Y zcIOWJ2#p`0iUn>4^q+9QbKLAYG8IxTzfTVQ`c?IH+R@?j@w<}5G@6T@+(+V z*~a~1f0Wc6Ia&g;YIbO%r``=$qx}y?eu<4ql?BN+IsGTjk+QP?$xdnMH47oY_SJfF z$+OqNaGB58LrM4FXuHx&@r`LKzD@v6u5-ih#us8UIqC;e_N7V4a>v(4ETpQnT|1)Q zCdWShai3&#%RCp8a9dhXsHv&Rbne&IZbMI*Xt}7VsmZ6{v~qEN zx)N8V0G*$_h`1fla4eu?5PV56wa0_!vRk)pGY7OKb@&EutC4S^b`C|5CqQG9#_OjX<#Gf2v7;l%6rxc=fpQf9b2A zdfAd0eI`x?$hNe!P-Si8Siy*l%-mIyfCs8HBSct}i;9#bUx4-*KoQQdZ@>NE!6q=0 zlKT40;;f(}1H%(Z%BO?d2&O~&nM@0h~-iRJ!yXZ(6CmnS-f1BK5qtm zeSQA2qkw=)iXE?e?i=d7zT?KXMbEL4?@LOqM{nfj)~A27jSqrq$pOb?xi-Is6t))b zDC&Gaz?i#S?2a8fq%Ss#uUxrOy_R9e>DOB`SApByx_!H1a5B*$YZspWoGUA?7H{a9-#M?5~QGTlw4os`(iF4L^VX z&N}5~Rq%}C+2^$&>pG50R}e2?VWLjCd|B-r=)LrB^a+dRLgLWjzmqG1_)q_vFbQzftj*7WzZ_KNu_{6B0gYad&a`X^4Txpe6gVW<$~gFL4H2=3UyrvLIx=ew~t-pfl} z^DF$%+1j#7N=jL){S&Z_;1 z>g!c*n9U~eKGV;r)#SLqP~^QVp}RcqDk5ub6z+VQh^nf)|I1!I@Z$8FuGAf-h4FIw z)#yU!m?O9RrPner;OSpRW8Cf_A(Nad{Po+n9pJ%5~#+S0kOjY?qPYfIhoD(?q3n%k=P%$~c*E6n928wjG(&hD?wq zljbiliZuGfIWD&5N%ym{BwI|wHHts*9edjXAG7nD(1;Qv=fXN`6_pk6{Q5Jt$~6Z6 zKrLLu&;LoYZtwny>~V0hR62TkC$H|2_96@DNmNg?9*L+utpzT@uB=5K)2mR>dU`)y zTj;gB!NtYpexv8`ZN4?E5%+&UR?D<+SE{0%HNj!0$|XY!6B#}nyl#h^ed02X|k zBTv}w{$~(!HppEHIsN8>V*~0>utE33^>B0gmx0;1JbCrkk$p0J^W(>l%>^!EFo`}t zIVJ#;&nXIvoj4ISlCgh)CNN>I>-g>YfdOkp_i?|7#}LBwd!YBT{R605WR>5<9A7?L z9>g^>Hz_7+2U}97-E%>&XPBMyal(iBzN~SzhJ?UTDR$qp z0A+70~JRZZo5Tr2p~C4}W>W-XF>dJRy7GD&>#)C(Xl71RNerht6~u?}88%PgB_Wvq;q zFfAK;AF-)3p$WJBkYZ+L_GQ1s>1ym$WPZ<8Lc62#K|Okgf(>jbII#B?+CP0_TeWIc zJWVv#Tz>Cfn2?%Iwq=!-&L16JT-30@!GnVjpat^mJjb{}*4n46OqFyJ>f$p9Xz>b8 z7r#9tez)dQQexsN#6o6fX7;P8(S82>`JcpCY3}5-6KI|uH4YN~1C$|RQ|9FCCa5_t zPu{lWbzt`Gpde0q&Mfa@MA1_`Xy)VlI0v&bC!==^b2M5BoE774sYfrNbA#iPNk}e};i=!^y76 zi;&3D)6?bIy;4$BHI{b3zRfUIE^)DM&ASd(P+C^D67QXvlM|4Zwv7-zpk9&>5j~dJ zlPpw=Z-3{(4I|nIuoowc@C%|GHl2qe&a&!Mb!k;qAcq!Fp>-yL z=G*fhT~>4(rN9eU`){(-N4@BzKnSB&qmxUE^E|t4c@V4E?(SAsUarM8tBI?_DTubhVWfJhc%!*e zU(d|AnYjPME8=>3-&Pd8hcCe_{sdn>{Y<`7yezUs-@R6^|F<*h_++i`)(kSNTshhZ z*JRtMN80n}*=SZ(#v{P9+%xMB95^uBD6?skkNzcaSji;Z0m&rsd1Dhqj?$sF?Hms7 zQif^_hpa~IA~a{Zm(L$OC`QdDi++e(9iLu7E{IY+M|-fn4%#wJ5uR zp`jXBt!};0hV{|N2ljf-PlZCi?>9OHSE#2=seIc=jzmrS(-V)}dYP`V?_bi~ncQv| zrJOu93mK%Qx!Is^=k}9Ni-9kCp5@l(ABmbecHqGE1O?{+-BaLQ*9YOYy+VoUeanv` zb#h>MV0idbn&zQH^v;fs+b%1p3_lRB%)fKz*!3bi$J|bp-m15Ck$y$XOGX#?5hqD` znG0;A;{60G8QDC&x}~;Z&bfhc@v+K$3($`H?1o(z%7|Iw4tQTz{X8S1&ZNleQ_;NjV-`_8X+aQrn*>Su;E zJJBmv6m{dy(>1*~p1eY|o$5Z3a_JbhgWtW4%xOmmm`+0;U<`t?=Cy}L2M*96rqlQ7 zBHaSx#poswH*<*3lN03hLrmv1$WUn5|4?CBOwc4`zmHf98_c?{{H>zRt>n~ zZrZPCPEZYK1H|VkFP#`_UtV5&7HlAT-&y8H3Qf+z{MKsryhTH>O~Obn%Fi2+xiDI zIF}60&J$9KJArFD`@%?Ld~#AvUHw)E^>PDoh9(Z#X^fAJb^6z?tLAljvlSw< z;G|uZsmA5YvZvAwN;Gu!^vWTuB&U?I-};uS$=#;-4UC^h-oY%eFrC=>$hfp)v5Y#Ah$Sx!^ z39rQ*A`5h8bC9CuB5!55nq@}E5_xUa^pG=aExg8ymDvwrk_08-`IBrVB4A)GzKFDR z6yUQV=IwlZG}ad&Sw%~W4fNC+5R`Cra38pyr~3+sc?O(#05`?l)>hojVqQNP`WG|T z_TtZ<*T8fw2h8jk88O{*dGY|^-T*)E@%#c)8X>WV_U(nfgA%7-GZMQuIG6=zkKB4hT8P;S zlo(cEa!K#`sd8Y!Ovmm+gxwstK@UO37i}7QJ^9h28=$j~@J8~N2Ue3Dh8q?4)Evj# zZDB_Gd`%1ukO0|K=;y~zoVboe39uPcxx zLt=+IH2C$EeX@i644Z~H;bv#gmj3cK&2ylEH?mbH!HXU^G7o2wY)){niN5y8Eh9M> zaWuo|{YXYz$5S4V`F|mp&Q^Ft(G_&XCV>9lTQ9m zknF3I5EW%0=yvjVr0_7DUgDguZww+AaQ1J25GfWZb5p~-2ritEuWdS-V&JOXq`36-H)NV0)Lxd~w~rsRNc{bV#rvx^ zDt=gRn+`9dFnRpnH!c`p`OhJ<@i~CmhPgN01FWI)f^l`v&%3UKSTf9nZAa6d#8C9QBCp>>S-Fs zjvYo=&iBR*mKsPar#ILNFK-Kxjz{c>q!Xkva&zU-zBCJU@N&l4G~iZ}{WK`@+DGX? z^=!_vWC>Wsf4XAV{4Uf-Cwkr=uWb(A2*s;E91HkQq+UlHNv}~5Kpg`im6fD}L3@kY zY00d*`GleVO|mYQBnWDdd&$^PzdO3;&mZm28?fegllcBq(^acgy;0W}CVFrlp=agdY#T*Zl4-RNia*sMD9NAWToxO*-ATsQsvR-zeU zEFjWF&+9_ARb7C(unNWkfJ+0Uf&sV@^aKxF;df8$>49S6H3Q)dzkn?2h3q&W2 zj^-u-)o0Iwz(d|wRWYGYiDfPT2JRPLi{S=yHyRYjRugM5UM zD5M#f`K9PeC*X@Bjc^c65XHYT)BOGD(07@%q-ij#qw(@<&}|}ZqXGK55O6Bbx{baM zG1!lP?-jy-#_c%AbxC1GFI@r9_g`oQ5lS@(F%W5t8@~xHJ0kx_Qi`2jT?9tp{5%9T zf$okF$X+Cq2DdLgu{?Lq45Br%p*P^_pmUx8V!S-kwy%AHje;ntLv1L~eX6d__>Z;h zTNz;LyXYOk6=!7P%)7a{J%FxA?i{WU@g=n7I{;xNljBpG&dZUKrtpS)dwR@pkMOYp z*RRt67(p`*zy);S%0wo2G0oO#963^o^ZD&V9Qb%~*f ztCuu1=r}2I|J|essE8QSN(6?UK7INC;unvU8O4O^R#F`YO9xW+-(O!C5HzX({P{Yd z*M+8+p(vZ(KsM!*lm80sihCamR`{(pLw_RLAsT*tn(n{;3`)Hpb?r6+=}3r~Vfw(* z_)~iWuI52{dI$&+i$oIQ3$PX#&{JPuY+41~p{(+bfaCm>9ST|B{M1ia-BO75S=-oj zR!2*UOmev)TPb|Z!fSbnp9FajwsC;X+zPro(f6tQ>*p_DI1un9bsOk|$Ab75ON-N{ z2y$>B6m8v*!qhJl9uYyjCDLe1N@z?>Opr6Jgk<0V|3mfAp@-)01Jq2c0M-l;|05U+ z#QRvq?%+FiEBDFNu8W-4gZSFFX*bG?T zn?RsQ|KEW>)`jhQ>(;Ge=&maTnm36^ZrjF$$YnW-xVeP|Ev9f#IXm$?M4m=m$lTIW z#AOj~MF*^y9kvY#QX$&i7G@FF4fGIWs9OGjklGd2mi?$ee$l#$1w%XQ-_8zN?nvg9FAa$&#r$li!;U@H`67r_^ zw&@BlhNC>KeKx#I+Cg5@YAC}gPB>>r7WRJ3R$j4n%sFpnnvIchtCL|@cUxuUlMj}W zuTcWBK_Onn$Mb1lJ#vH@g=`~jF~X*`ATsT$RBvbyTA8LX?T*L%0`ELAH8&L#EBZKj z(Q$*4r#x*jxxGe2WTt=yR5&(3`L2^Np?*LHg~uXfmu<6D`@~0lW8rBTl0z^AGTu$^ zgP-61l7sP~Oc0RTU1Ao>FEi#R2k|CA^dBZ7S$g%U62i-DMLNsZSC6-=3hFU!vw-mP zp{nX%3J_x`jwHIwvFl=^>Zw!fus^oyy65-zf2X?3*xK469 z8+!o66qIYA`IQhZW>btZx(^g~UxZ}95KQUm-GTQj!2n zeRubW@bFQDb{_)uBj-n?Rcx{VLacz;ciP=qfb+B-c{?)wWp4< zF3j7u=A9^%T#l&PC;1JfQ_8HELb5wRbJmEuks%JshuYeThibvZ9Mp*Eot^0K=WQFq zjYLfDdSD9>z-Zc?;wEvg2G&eUla14?u-MP#! zHgIsI@y$G2&YV|J@DRN(e09ppmoI0dl44?R z1#*a%B07w}>c4#>Qnj>s=}KY5dB;iX!91NGOa;6}Qqz`yC<`rInOsCmt{#hN4j&~EI7_VQye(>Z819H`^ zRQGJv-|o_UpjcsqDbsADO@Akb->^3yVG|WidhYVKyEyKQSHTR}01kFsPfra1 zMA~b?t#;P*10dID#4oUi2&Vh3;91rgF38>^mrUuy8sBOp2k}M3>Q3bO;fJq!T=hndY z?X=}zv@6QX%>nXI<;$_x;qnz3hnX(#m2~7e3=sjqtF*b92g-IaQUZ6w!b+;Dc@F+3 zl1YPH4cccisC10dwSuqeA~+FhsCpuH598u^F)+4$*WW)x@qp>OUGLd`07fLzp6WhX zMY79SpxC0OI^y|Z)=Lc$baTwwppW4Em%z@E(kUjp9=m`{(KX~cUcRjA;E?%sJk{7c zs)sMcKi})qzRF?jJu@pS25=<^M-QO_ku-SV*Ap9oR%j@fF^7?xj3$rv)RL#-mDy|9E|bU-miqAOpeEC=e*&<>a(o%D_HO22vq94J=mNJwI%( zZrILLDm1v!si{M7k%`g+B#D@TpoE0>g;|v{XVvJF%PL0`mkyd>`ESBgdiU;K3cG&& zE^G4<{&iN%5z+zJEwUFIS@-POLl6#HrIhQRGcZra#;c92RMdYxM8fdiD^y7m=Z5D^ zMq9xfM5;F`dESCM^ax%NPP_QYeD9@MQJ~FLI6`FZ7ye&!zO(4rvu8oTX<>Y6v_XM~ z4<)3aupS3}0Q=vh&@&JoW+!xIMn=Z{s3`Ki(?j1maQA#*e1KLn1HS-QbpT^iS==y$ z6lD(i5+ykRV-L1nnSIKwkT@BX~5a^)rT4Xioi6%~ET0udfs=#@W zwrLP<{{Foakjr-#Av|LJ;@^|8({F`_`eX8eVg}8C1oa?C1`$W(m&b}?;LEYG?!%B+ zOFw*2f9!$|aqpg2{1)hKBEOImEd}x;)P|p*zwIiPSq-!R&c0B-L6{Z4G*BRtIT}B2 z7d<0kb^A6x6ykxlcYD!2Zy{Fr2>Y2JHdG~l6vhX51M8m!yQW6k?_geJKT`ZsD*;lo zQWODjQyxD)3uTQAWs>!SxPD~(Qp;;51jGE$B3g3o*|CK@XGUotjoyGU*%NTVW_7Mw z*R^4NZhn3?^af89V505e-i{(3f-^+}9kSZ>ClnT;?c14&DEap7+nzklW5);`tbc8z zOWa2k4y0t|J!hm~NI1gV(0+QZqf)x*nVwk6lP8Zz)5yuG0PF@)IAOMrIsEQjAG859 zv~|{qH3InEgfzTKSa>B~fXQ1SdRCrdLs5*&AUE-k7T}hHI|%OG%Zsz-i(n8wh>S~i zRGJ+?Q6rt3tig{TytsoHEnSb(PLEt5DBBbD_5-Jiek{Q#I$V|i28|m{vjnX{8qrp? zfu6UQ2$#a^fiGTLe_jFN5y{G&sB~=uPU()uf*^)32u1~>Y;@vSlF&b7f({7JfCn1J zt_>*~ZsR>nWNh}=$bxoILu1b_oU@wyyU0M}02)5=2?@M6E#$-0b{V=iiQBbZ#gY;^ zocOGAa&n-(bWnB3dJ|m63XwO2v!BID_2-dUz{<=4S|pF zQE$)!Kcr?;_)zgF%Jiy82100%26Z>d+Gk32yAGOW4kHs_q5o?qylz7K5Z)$*A4L8v za;Rhg0H*-!6HGNeqH|!l2aXtt!zg+$@c~&8tsgP}l3zwafp_Mi5pROpSb3ZOJ#rMu zFJ7>Dc+9H|3tsqX#D@njA*+DPB-M(6fdN}~;D!I(aAx5Sm?Mz z{5*RuKdrg;Hz^Vz^U9xACNh=#=O-^N-)*9gjbshMAsjZ2Q=R))ozA_4v3pg#=_4ej z$;~BtB(%DkI5~EVX7t~atednt!PPgnD;E%|6Awi|P?8qu0>{~LOQQ9Y5(FHDyu|U% zlRHANA2x$U;xcX^B5L&=;tCA~RaI>+1fN6Mk4s${tq;*1tnYxETVBU5catKcqI9XI zb9m;#bv&#AaQ_y>hhiKD1V;!jhVe_-JlvI#FF7v_v+0ra_l3)c3QvcsKon7I_J6^F zVGqL8ab`(a1arp0__~`xr*XH5EQ3)Zl07pTfsurU$b^jRXC#@)>XOC?T1Fm$Airmi z(>TeE^6s|1%i5uzZmyeEC zDkB^m6i!85osqa4JihOwT(rR^b?x%E7e>(}b)DOWZG!>2 zZ{%aYhR9C+{yEo9=19)dQ4nO&u zj5}uGHRMSKa1|;r=3*&fkFCc_$BkQrmagUCaG$L9@T@2+Q=bwwTn4^4*Rl@q&Kxf1QC;2a zviFCg55>F(F+Oi;Sx(8sd2@{N4+uEo^@M6JQ(pNJ$q$dYzuPiev13YdHi?M%e|u>h z2XWA`WdOf*3a%r)k6n9VLNCLhs?&d}(U3!a!AIiq*MuncyBf zDYj^N)Je4C4+_!n*`jI807K3qp&C!c&}{tZTiCw-ZOS;PmUD^tN%*i^R~b`mW0<~; z|4DwdV%lQT%QP^tu?_Fe&1$|)=?Wzir^_*lTnCLAB*1tFYXma=myE|Z#$gdyDMol6 ziHwv~NBj-7&1-Oic%JOwgg9z}jkJgnhVa+zJ5ny3_;J5Ac|^jrum|1yr7 ze$%6)9Vx*FO|~QM0O7zX3MBCkjTsD@&Xgjr#rt0*ZY98qf~+$-Ud@ZirR+Cvdi}aI zwX`Dt3SAj#`()S;TuOCX3UqVzkD2l)GI z+}>fRrmD(0mnU|?&A5W+BH#&aabt#Qol#WoviMT=>(}SP!osKvT}sowbS%Pl`~MU{ zB~w>XSeJrc!s=s|)PvF>$RX1NjdPsW- zv|s&qoQLvy%q!4ta$7Y=W}n7RoH-(tasfgK2EonXWcMVzUEe(=nVpqofUs+9c=+1A zUJLS~rie5>$IvE*)WiXz!zECa%^Ceq?iiXBvsk_;02}fN+wEG^_dK! zQvHsRsNvkGp`lUJ(sG~14Ca+Brf_Z1G1Gxbjv5;$d_r2XcqcP+S^*Nf^gb@bEqV@Q z+G0H=O*g(rMClX>R_iW3+!{JqT6!jejYrCd#thPD&)3#Gf}!-reWC+y&&Vx*L=Uoz~F=ZZGKIs4eWh^TpAXnm)5`b6bg zpau&a3NfQJBTgb>(^Xr8Lg+s0+^$%q<@q_uqh2T%Qo1x{2!$!+)iRs>5@Lv1*Tz(? z&Wt8^nB8x_D|nEH^7!e~4~PbWUsN{Z!kbZ~V&Uh}79%9)mhcr1v5+E>Ji5m-ht#~d zeY#lc@yG;Sr%tZ5YjvKE2rmu$ld+)YB=4mv47!}#J_xay#G;NOk~1AgbV3^CzQ#39 z&jmZu;c|KdgQ@)Gtqa=C@~Qc zy4MFqMMOw@A(MBp^vTZbdDU~GgKYX^t%p*jj~%yRl}8npLaj&V+9Gczu5>7cP5#LI zB1tb%9Qw|G{KN>e_@Ij8jxDAcBl|SeNH*ZAo47Se$z!Z0z~|GYtIPT6C#clxUp5A!y_~uO$KM z>_i`}-GA%^J)b{))@?q8aC>E;4AcP`N=;qeoruJC9}CEdq%t2la>SQCwO2$*8^SIf zLtQ!TrF;hc@re0Ois5olL6QWXsX%jg&j&pFLY~aYBD7gaiq=^Ri|;-(v@zkThZ=XE@f-$&S@fK+w6c=HAT9=vtBJ~Igi-8r z0RI@IGK94jd*LkURHbbxAMW zw6m0EATVK@@iX4*0_uG%%ri{ach?@c-vk`aGiuR6&72IM%*bq#22DPT?%omkg~VGn zUX`#e9z$gydlFr2Ea%aTeE1tXfO$nkkSxJ3eAx%MavL@{-0QtRghF~$Ln9D*4$aY- zBih>M(JyvEy$GLz_9*S6%r0u7OwR=2#l4@#ot_V#{wGQ%?g7)wZ0H>HK8r|*hRcpk zTEg&gpG8ZFLG2`cD1|%-+xfx^s!P{NLX9E;z*&lp#3sJSU2~1FU1pd zvJIoBIgXd;*yKN69XZ~%^hCDn$~7=kswD=b&mmhv;MSL-t+okFf|lLWw+`N&41_41S6p(=u zcS&UQVC+d&MSCPr1tloOZEYKYp&*;~rNI3F=O;phuk>vB#6e`Hb}zi20E|>{Uih=lrme`Eof5v<%N^bzWoFG;8}P1h1yLdmKrnAb9vZ*&9>G)L z&g;m^k>oDsVV>3t9>E-}58@MMpC77Gd_XsuYM%%O-@3)9VbB0rf`@3Wz$4OP_PR>o z*T}e4+Qg9+$|@fD`T2s7Pmq(|J|*QoPSj0S zx?7Fg=?6F13UDOWEN}Bzp#$23#%b&!sVhSU21Z7YP^gJvmyp0OnvxxZhyNh$z%oBM zj+~_&5w5TW{NY2ts?nX3S#94ODu%DevIBZ0KI`EH!t zaPR)6L$u0MOv$q+nT7W3*#M!DgwlvpX0kQAbu$0*ueFp9oag&N?coqyPe@4kI;9}T z!V>J+zbmo{KsX%LKjI~`5fqqlaq4T{ms!mJETneDrVW=9@8 zwbQ!PHKnpOJqBJY5#aGeB_vU9xU#Yga{yhE=MG^lNuER@L#l!dT46YRvxLMd06c18 z7e0Z6TVj9j2n`PpH~S!+>?}wjT8d_r_yJ7SLoN6iC%M(t=+s?M{7qO#cyP6HJ@zKM zcklM?R0GuyfmjQD%up)czyu)ziEj#;Fv_H$J3$BSgn#KPxGEZt$(r>Bpnrt~7ctDC z?Nws)D@;s+EMwaf1tMH(fyjsGnM6$ndF@7!a4qXZoDo*u0ir0OwL}VlLTipG?zqRq zpT|Q0e6d;Zc#EwDQBWF@P71}KPjPYaruK<+WDv3NA`{(#lhEj}$qpeINgnh;VJoeiHXQSeA{t6cmGh$EN&k0V zpR^h;j`=cv3N$)H=F{-|UHD#v=ixKM@az;MB*_yt$hU%qJ~5oMt_)A5BHtY&ALIjw z^Mmv$ygiJF_=kqFp{BDaT2+E5ybOh^s5|iMyw|BLlL+HN&-HlLBGPMNE&95!E@ops zfBS~t&sL_W(uZTyXaTOr#)x{e*7b=I;F2A$$bI2U!Z1D4{3Rz&kR(fiJ!lx!k`4pZz5c-a>Q0=##WnEg$&;kgmw2 z5*7?yvA@Q1@;4XKc`N%*!{wnX{weCj@ua_vq1R9}7gh^J`I6U?MA;8v$uY<6ug_4$ zG0TyAv_)E6yYAg^WypG3Xc8oqJ_CA6a)6Zma1EsK)JaTYS1qYA{{lcDo*`hOKaMSV zxTDCFVX9VZ;(cUbB%-4vwhYTrQH0_X62=ziU16G(ynj!Nen}p^g2X0D? z4Hmy~x)g0a%B#zVfW{~aAjFMPdJAZC?AS?|$61LjcaR+EYx7rDGB8-+TI2I+y$`F| zUG)37{iyQsbgB6=Oa&6Jfrp2uxV)SWBrv?t0Okv#0u?rLJ(j32@@`M?STMLB+6jQG z^_X*kR`m!p92Nn2_7@()NgknxbISEhdpyUqVNOY9aVn`C5K9W4FPUP2mVorB*w;Bc zjqE+~J0VJH7q}VWj0d?Ukos0d`G9@E+*(T{b4Zg>62?Igoym+e#Rpjdb7b2vv&Zct z(6sgc;qE=$dhXx<{|K3vki98dWL(IKtVma)VP;hJNGXI;$le-8N+PQyG9pSt8f2tU zDlHOHqN%>O^LoEOzu)J39KXNd_dbr#c)MKndOe?y$2soj{eC~AjwpR^embU_LkH>X z=(VwH$<5#gqK6c^1nE)1rRCFu2)*A{*DaO&DMC)*G>z~=QwxjTBwQiYGJjC4%pTqv zOcM%rYFTW$yX&@ezscrri_z023HadG7PoBJ25{xj{jPX4!X zW42W)b5SO7hYyr59xp=|gL=~izSK9Vf5Zcjfyxf&7mNH4spUr=ge$0F4P&^$kG$wn zI#lQYWVw`{kW07tH|e%i?a?t!-xp&;Ky){h?W_L-R*hKZ)TWo-_+V9bI1TSf5L~#s z%^pqDh~>v~1LAqNKNymL*NKPZoLkFUk)wA4swJ2mR<>I{HM8@xYf|ZMY0)Oqsc(;_ z^nxQvU$6^R-?nM&O6Oz+jSM(kpR>079gzGD;!>wBU2;js{PD+sR2P~N{hKy#emr8W zM>@Git(|`fHUb>tcFD8hiYwG_9WqGA?!K;HL~(w+fsfSRk>nihFE681hU|@FnzQdQ)eH-#7NXQeT2@aAN(b162D*T-@ zyVtt@7eD;a{5<7k*xJF{>$WjguqGydM%lIe9bKf;h~nl~_k=l5wh($Z7&wt3Zfmq7 zAuHn2#kg-5wQZZX@2B%4Z-)!;-jIU!_&Dg3;x-o@91O#t3-yE>lVVa+zs=S^XPc+~ zIX`Hv5gg}|F|(D$Hg}xp%os(SP5!i`)by2@AJK)DS@z$g>bhv=%#1GnH-WDj+v+v{ zr5I=Pp+C*X9xA1P3GU4~y9H;HcJA0=bKO>-vhoPt0SoJkxHq^??8K`_buIS`rj!8y zjffVE>!5&k*c>nK*xxB)|k+dd4Dj!#%Xl~2W>RLZfhUOgj58rvZv~ zqo=evQ<52OI^@TEWvA=o4jB3tr{`j=nK)tW3w-JJ47Prva#Cm5u=dOkDQ|ri`+a!# zuCs#f%)@Z9ViY+S5#_X64UhIHK;k^>W)?`)(k&0Z&F-5d7qzqQ`d(trYZ_h_W|}+V z9{`?4=?S(u{V6~JdQMzf=dM=Q(#~&ES}nY%f~+%gYNB%87`89 z!A~BX9L02|zh7lGb1H<-+9lB+iK9DasDO${X(s0yK)i)HG|yc$)goq;GYU4n9y$R8 z`DVLqk55M9ws_5&gUp~>Z0rWu*OQoS2vp2Kykx10;Al z!si;?3Q!N_?%fA$eyQ`zbqN)Kk0}~zh-~Lfmew_o&Cu4=G+bZB z#H&D)M;))wOLuC&^0~Rc=gm8L;k@3JI;M=cSREem7WkrXvn3v9DhHIlcwvG9rFZsC zjGGz~Di=VYqY7`l+x(wt(=^KKiWom+c=5!AI!A@SpYiBNPHE(uu5Y4_bghC}f<;&l zY5%*HYpjSAPPJG`7UllY-qiKP(W54xvm58^{_J39_c(mRC)H0GnU4@oH+!kpLp$`= ztrOK3s>!<|RuzgS&B5fqGUlf3zZJX^HV%D=8@=4Z2O_h_4Fp1MGvA4%A*W-?5C&9R z58qb|lh}0H<@|`64;0cr_%El5=O3l@6L@En>AcSi=DfYr-t|JHdTM$wUYg~-jOR zfyyYK^Kf#yb?cUsljF+4i;9m|q7j)ySt)@6)xPx37U^^%;slFqB&!Ue zhZu%~OG0I@(Fm+kpH^jgVqZ->ju$&B*6b4ZHW`=#fH#s7j9AJE z*ip=kUIvnUoCW(@>GGq_lXUJm{p`o&XMRr9Em~h0*MT+r!^w=BwdeTPMJgBT2ifIL zg8I4yiJ6F^2x=|<^g0%_zh>U-%{c1?%vjlfy_@X>UkJ9t=i}Lq5Rk` zW&9&?9#f3Mu9J_JqMO8Tcmf3rgz5OoR~w(yjo5edMeMim{@xvum>Yrcw^Yc0IM9DP zC5+%UJdv-M4jnIt2HeC|W<#l|X49%k{-=V2gKcJg%j7_v0wW>^w5^5R4Ztc;dADnx3EM(n>`-2s<`}Lc$KK#<9{)$nU+{B#n`2&Ii zl|uN@p$K7BDn|lHf<~cI*_^~Q=H$O4E}5uxveq6p%u{R^m<`;39h_L}92@EUeOhLm z-nPmpkXArQ8tmsch3F6*{ERSSQx^d%hbalvSf6t04tSa*-iT}iU&l_?UzvlpRM@a- zllPT%lb8ILk6Atx(q3l;v0A(R*DK=`-1#E1=Pc-p$3b^z2JiFhx_I?k{Jdi{3g%fzLk`%)^|o!c zFV7=*PLj8KWvzCJZJ%rOgF_Y8;OgR24qQD1QtjhA)3$w8SWM>4)z09vC`Q33P6uBEqLWoA1zUM3;L^8M3l2pc#E7*w;(nyZZOn<$0>^_ClsSZ@I7Y_L28gf z5cD0NlEk-1?ufm{@%k3Ok~V04&NMjGoBV%K-cWEl!@13g*g!J=2}=-xcT+$+&N+No>k?G?

R(ilh&k_^&YiMcR0$^(DRYBAH{C4P=bU*jrsr3V8SQPE- zIN0g^?c=I?e_b*<6+JB_IoSr#me0)TDH{;LCv1Gjlq(kz#=S25{{6Xv49tFdDK&09 z_&<|S)nD@-UhsV4xCEXVa>z~!!wE=_?+nw>*mo-!#mGrCd#^E{qiM>^$M{d?K@URk zVLi}bG2Dop3*uqC{w!mio77i>>G7{P`_U5I0%D(D*Hc;fsBPrz&zs>p4&_v(UwB6f>A(?Qyx@Nd|WF>~t5a-IF}w+mktPZb+z23GvEPeD3s zw|4xU&`FBJeiuITnvMfT9x(7Bj~|CHm`Y;6far z4gSbAZu-|>QvkA%H_iHdXEnWoymQt)NNUH`yl=mLO?dD^yn(KklAzOepR%&DE2Wx0 zO{^$$IUl7+*I4Z^byjPqLu{U)sgfp)u2y^}79hjqKJD1EXOF7cy^}c6v*JNv_Uzdr zzUAEsvzOrVBR`>?$Ok(PTyV6ujy~=&Z{P-~nv{*-zke6+SAv4PezY=A%gY}HiJFrV zXDY6u=xTJ1$QM10X>1OG{cYIFc7z0>cL|{h&I2wtOCNnaG1}p?D|!Ar1dhuj8rUmZ zPVF3kSmS!EGwx>b5PF!fpFu>0+_eB461>-dvBuKBw0=L(cz!34Qy~UiGSlEqJ*M)3 zBrCu=bT0+yCbcyskvod-pw{6(VwnQ*b3GIxF=M6w0A2IW*IM)lu9j5l7fV{$BRIJ@}8I@rvQwwCT{H z32iP8gcFW$0`s|tr{2b|Z{08=JqKtF1U3;W06}&=_!ofUwD$lM|3JcTL#-?PN;$>Q zEdX(`s%gmW^qf?pEKc*retNBxl$1nJZ(yLr`FO%Jft%YWz1xp*CqOzr9vN3B;&LJ5 zi+w+2k~madrQ{G{!c%lBxRkdb*x`}9pX6#?nj0q^e}N9aywM^KD=}W=YI!BM*8G<> z_PhGmY?&crFJcYuh&8hTfdDI7cMjr5Y6^xzg_BifD%*eFWB?0=#h9jZ=o>1qHADg1 zzT;!~7bb!rGTmwdt{Q>q5ij~N>;R(H9!lR0dVTqxT$?Gz#)mK1_DLPr=Du&`8SI(S{37M>dQ?N=TPavlmT)xW2B|iMzPK{_x-tBjS&%O&2|xEFJ(HLc2Z~w0;kK z27e}u5l#*bWMRh-d=_}svI|Abi@CF+Blq#~+1zF@ZIGJt5$I(i9pm@zq@fWKF{3nD zPK41RV>7qRjH^793;b{PCUivzePCQ&r6OR5 z(CmdQjF3q#r%32v&L-QIdO6gezx-zpz{3P~Q1C4$AwUrfgrCuVRo9Le?u9r$a4g;p zFK`Jh7s%oiE;+D`9H<;_Gu>I8ZRhs;j@IwG*`ri-SMBEgjP>5Y=sH~V`ST#`GG!c} z#zLAqj727Rs+F`GC!{7YVQWdUOdzCWk&+P>pz)6l^pS5N%gDT4>Hw*SyJA#S!@{0N zwks^d?cINU|8U0q`hrteE?*u*Jtn*+I5*>-XN=}IECnmKMklfC$&)8WIQ^SnHQRT4 zZH8Bc*O3S;O6C$)gcx(NP!0`t2rXG!wOANd*z{8m3T+6!d% z(zFby98{q`J~ms9&#%q+RL4IwGBDUPZ0&mwNU8@b3J-7gQ~x|;#dgDK)0U8rmvbqP z_ctMlKi+!&`sHHmm-cq<)M>qMKW&Z3vaWzJhrbd^g2)J1xSW)t1lr;aE@^3^81@TO z>>9QX@Fo_+Vh=*CvIaHeS1x&6>9CPD@8azv8=JH>%nRc+h4eS2Wb*bI#Hluw~ zAB=9^ge>x7Bf%hO{%J@rH?nObVPl8tv5Rnp)QyFIw2JyKonG9^;E79w`w;kjQpEGK=7M9t zCFZ^Z2mR?BGnh{S!R;o;Y$!pBT@o?$@$rnmriVSEO%8$r_dlc+sqDEZEvD!nE!?y< zW@eCX#%ZdRn_Q_O6o|sG!(u!d1(Q}~2mS9G*gul-76*5{5>n^DkDmjz`YpfKsTd$C zltIT$p};}Uu(&cy4d89iZ%pDJZ%L@x9l#*52Dx9YHSk9`RHuCOJl`Q5OCs+JruTUJ0G*LX1#x>JOQfoM@^l8hC=FdOF`yYayFpcQY z0e~u@LhIp2^}1(2Ymy!W6Eo*>?f!j|NW4sE&9WjVCgahE`Dx$myXWcaWY+P7rd%jr zY;9OdZ~2u6G9=xJYFR7{l3C+re-HX3E znf&tQ%vav_j*d$)*l07Rqc7As4(yS+^XCuc-X%Cai;v%f%bjUqiQ17@n_~u#9C=C{ zsoAnL{V=^_-!vHVzu_%6ozfbpLGcVsZ0W2Z5=&P;cg`GawMN__zGs~YD_gid@x1Q! zMV)JIXRWvS8XuZQmR?5RjGFKy6I<|3DwoHFJ3v zaKgF>*@j58jPY>trbVq(NkD~uR&BY`)j5u-sR0L~&%uXCcloNmDW+Wh6Ge_Gj$R?tA^x)FO}2 zrCB4-fH9qu+-~-1GJrb}4c;9MDS(&GQYzc9A0-M!-(U(&zwi5;OGDc!D@QqP-PbhW z&Ti$hTer`ki4;;g=2-t4SC3L-Jw-##pW=tCK5+$`A_tP^#5*V{EVAqx#@rl$`n47A zxfA}(n6P)w%iK0?CqRcmXCetp$D>8nNTV~uThWBn6~q`Ft3OPv{77hxC@1?Y=I4rV zC?LSPbLv6mveEv02yu1r==WuN+X*)M8!!#3dksHL9PSW)=Mp(-eKd+SkJXEcU*tiw zBBa_a;#F}<6-_cz1fEoTJSG34N8|O0FNKMjSuFH9MQp@wUhmgCzoWt0R{uIhX#A5} zue$5$4l%>1cfz4rpJ80${VuAl2{~vWt_5AwpXx8kUZ!{;v6P`g^pHs>igJ8b##7!T z-rqll8sbTW2ll1BlLD1-SKC?{1%VhVB@DI#v|ux2Ae`{UQfI_EKTZ9DEDb` zIz|>?R&9r0(MX|-SyijaQ>OHxqmfc`@;^;XcgG@Q@3#NX9h>equPgkIQ^x;z zY2>H>$G`Z0S!?iD!e6@nYs&J!|8$a!8UFVlh!@APuK%$E`G0&l<5xw1j0l`7n%H|sy8Hw)!zuaP(qw=V< z3mS)AR;~7x8tX1PT3V#7@Y(6%3220_zmK~5<+-${%ekEk9q@p90G#LcS8n_sErGDWb~+%Hig}x*JYF zTRh}hDo&FNcp5z!go2?T&BAR7O`z$RK+*KS_7*f$z90QSIXTVLRtiL+FbKxbP%$i{`aZ7Rv0~(fB47%{3oUcvTRrF7P4|Lwj*#RWs&*Mh19`;Ft>Bq3 z2$#no+(4XTL~%w8gSO;pU955>^HVgXmK6o+jw)$PX|NjR__l`Awfn}6K}M&BG^VAF z_!V12Px=~zp4e$83W0;rW-2LUxnKCtsy!k1Aw-U02zFB&t0!L76OOpx-dN5=J$>4g zk6{4TiK4lGf>!x_Ep_$rxRF3jW>=CcLWI;&bwu-n1+HL%hso*TMm^C z4Oly=sYVO@MnyYEn>Ihq{9{65Vl)}pjvJ|}stSp2?x00XeQd68C(kXxk%bP{{+0Xp z(0#3dfM`>4FB(jqEXJWzr%rX{OY-OyaVrzoTkjEQ1(|-ua7%;*B34S2nX+((uEWUD zyAVDoM%EqVF^KrEfB#ESTi4Lm!RdMoU?lQ$xS;2@*VpZ+=*>6W4NHo>yKRN^LACr} zhb4r=rDc(bOrgF}HI6vSSddfWPq&Dh>8E@u%V;=J4MaGj^=i0s3w}B{Z7Gc0_S2mv zmofkqpo{Wi+_yi*6YL-0Himrd2=}rtLSkU@c#wBy`TZQv@+++PN@8=S@IRAfI^g~% zBL}TL5+Me;T$c0T*!G7~JpQH;l6L?vYrUhRyPKQ-73YmtTI$w$kA;)O)%zdiC0;mU zQ3}x$HDDtgbG5>Id;38u%i}^)%1nYV(WYllo!QPziI2c%)t^7VE5MqtQK=3D2jUNx zFs2kzlRs{c8Y;W|mDPo1*+(nLX!G4zXcGJELv)_eU) z-f4kYY&>%Z&Pz>FX0HLp6C%A&YR6Bt`;3+t#Ora_)JjX?GDICY;B$WAn%b|-k?imQ zI*BWoaJO4{Eke223L=3FQ&U08Z$b6$z!%+n;DDm_!J(JcADL`izh3?I<0qL^)a!9o zYS~Tq=v{4O_zdL4<-7kw^bz9IBhW8sCH{HSya4Q%K;>P0P#oIxyNiaez;aDPS9&0w zf3{8J;3&y9=T@!a(6fqh7n(+e%E3Q`bv5Y8i|dx7&9Ez=U=_bE5Jquo!ji}M`dZTk z@;N|J{-ydH?60bt!-d_E;RLKFqGP0Hyx#cE3J0-r*}Peo`H$MpLr&@KSIyvG)U6L$ z5I%$wwXp&`5-3dFxN5sOxz+^v<3qp$bc(`Q12Yi=Hc){$VuAo(FmQY~#q_SQJqw0} zo)E&6c;|@)E$I5v%pFU`H{&FY3co|P=;G(li|j$yF@@hjHJfM?70R?x7yrGIC2R1&($> ztH4tyPu7~%IOy*WSTqXw5H9KN6JFWF`>;}?ZTIen2#kK~qQoxaBx{m+J-Z-|%Tg3J zoOVo}e^M!$iIe)TCSmmPV8hXvhwC;|;0DG+y%d*5a!lW$|FOlQHWiwgS@1m16WfIt z&4d!h$H#&KcNA^12m#DGnkPsk>HTxlQ^IponOxI*%!F?l!S&9rrPO_keqKDE$V@^c z)x|DPVCRyO)^v^2vY-e_+0Ty@V=gwASPDKX@l1%)_!Dh-o{b_&{HyVG$zcx~-jg{H zq2`2IHRj$bAqbzgkWLB*i+QsyolkuFbUZ7I5Z;!Wj*0pMW-fM=w6Z$ouGaj9m$K@t zCTW+-T|P{n5CyS&zIaiMEAMAp$2igoyhCil_%7lxC14oB=aVCBd)Rgh7_@fAObUX@ zWjVzux;af1gii`qE;;1!PvdF+9WL%F&`T(wmpx3f{77lSHlTmp3WSjZ1}-ck3ak7c z>&IYB2}C8)(VwJyyd2y^oS!&rCv#i_D)Y4q`Kt8vd1M|D2wVz!bpfy4YZ=5;EX5A- zFO_-~Agp16@a%Zd`Q+FcFFj+&zep(`8&{S#%`5xb1>^~JVI%xwZRe_#r_OLOwC`E- zZm-e}>@VGD|7A7takqd)39{M_SRMADeUonJYWI=yID~zP79#8kivg%T{%d;&uzGk@ zstoaTJO0frmv{$6wjK_B))vI{QB<{d+yofIhQ5n47_LdiA+lf9Ez447 zFVRGZ40y0LVuGTsshLYzDchsPS&YXpy^KfH0eCVN>>_bn)JsrOEE8oJk?^#i-=ty? zlMX5=;btvAMLRBTY|Olbz&|H!MH}=ujPhOSB{3^#$_-PgO{SJga9DDtV%rB?%SCN2 zGl2hy(sxU%Nr58c4m_i@olahnJCm9k%b^vi0o~(4(>pJ0KOjvNeJ)pJ+#lCGHu@eV zw%{Mr^!Fuxp-Ge5)pi^Dl|mu%FsDO zK|%cRd7bj4(Mc9)k9Pc`0#!jqG91ngUgjYSyUPwC&-6#mTSb;lp;g#Yx1lQoG@+zX zza}|nX&%;#DVeI;_RrrnVcG-+!x75vNI7zt-Cg``Vr^%2u=m)p52x;IGHG~N@b>Mj zJ2|Oq*N1_UbcARpyc0iN1-)}Pd;Ik2{lxsoJTkgWkIZr6Pr{oDMZq8oFOQF7)J2)L z8>$KJqgvDMV!cb~3S~Fml?wHxpc!#?x%UiW2I1a2?LOuWA}Qf8K7aTcoHd6pNPG^=cXQsfX%mw-MRGiIi`$EB z&gRN0Lx%sgubj0@|J?Jd>d~pnNo7;y6+m+_E0cr+z_m@O8}eQup))qJ<_wa|oBZgz z{}nnZp@5=&F#MeteP?XsXzE;yDHho47@M2lHq)!{zSnRc7zm|;oa6j_oBuk)NS-b^ zbeYLu9u+irq8b_Kh`dLV2l&xr=y5_ZUH~~=Y7KCI!aYVs`zcuLy1b(%$U_=0B^mK+ z4|FLqd-R((?Lndj+V9kMX6L{JP)03zyj?u# zhy-?JL!T(OZM%aSSyB{pa?*tXz`1O51in}w{jA`xHd=pu^Zd_3aw&}|1Q0oa%0RxXH4xrSQByL1gN`jDY%Qe|9K*+|8ck(@dLm8A+Blh)& zsc@sY`BLRzKq+b&?t_pAb*lidZ(x-=Y4T**B|`BptZZ0&`+4>oCb;g6x!LSu<3^3- zGeP?iFRH~hHnKDF@zMTV>o{gJ0)t_@G?7V`*guDagyiSXIbF89M~4zyj%ExQp;6@B zle%lHz7N_!uAyF5680OYAkHJ*={IqK@Q!$_x)J9E{JWKbk0&t01K4D}&~Ci`^-IBgkw#nK2b(6m&$x8N!i*pWA0Q(-*%B?$r<({YOJzkD24UbMX zs`D5o85t$d7?<<;`}f0=<;a$--E>zS%Z20Z%8w++V1-W190VdSFaZb0fRid`NNNU# z3~>u*v!_EOJES0-Dfo77sXs4)NprUhgY&$4f=4FQ`EHFG%oa~B@eX$LriPcWmr9%z zWTGsibFgf={4%45Q?FLm!3f-JJsbHFo&hnogei*0}l3Wz6+il%|)69!4|- zWF*7GYboq#6My!2+r}>m%9T4U3`OVXcTjB_#w@1i?aW`_LygJ-Qi5v~v8}2tBa~S^ zv)VRyZVPO{1i3HrCXWB{plsstv}`Ki8gb8{8r%u-L$R_=XR7ImfqOpK`$fi$Ftfh| zDaM-5zZr%oAcIg1$z!E10Yck=fPOrD2}AwCJUR>;>iYbgITI&QPVfj%jhoN&~`oGT0_ThEx0GNIMx z4hE5b+m6G9$1DxL=wXT{SjugAfIIkX^f6-Gw4tnhpmRIhTa-DyIacD6#VQL~)3jB$ z|A@II`e+23O4CJQh)MF+rW;y@%!TtL24itH~S`MQZ zz5T6APso@`KnPU1vc{AIE{qBv7hCS0oKx7TGUCPcL?nb#V@ko`@Es@gzoFU-(OAJ1 zb1|y@k^Wgp_VYjW=aWilDP8ZiYb_8|GOV_OsW!wtCpjVESfk7(qc`K029U3I z6pK{~pGydmovi=$`hsD+wW(}OlFe8!SJ0z0Bz*uMD^WPK-?n&0kDHk#Fk_@|Y_NXf zw@+DmJOObd=SvF2!Ce*Zjtr2=rk{q!RW#xyDtGF%muh?>`WzyaNRyU6+79t|6n#@? zuH#7I6ac|L499`&vJzw`^^OZ0T_%+8)o>U)cus5Tc0JT2R z4$21U55?n^KCN^M$9^wkD$p6zMgfFy}OcLM=QME(|dEa!=_E~6glG5EM~9_#APoFC+^1f zdYAu!<-zgq-o2Z%VE4_i$KjX4pE{=)YzyhXDf<#|!rNQRTSHh_ylT~cj@|TcF$r59 zZXkjtjD3Wl1$eC^^a!wxb_cap4@58Uc1D@Vz%7=Ga}<4@@HFtwqdtqMnG9Kpq~pEsP}JvbN6iB@b0uF8+GzP(Kj{SR^97gqxk3eZS-=YDQQjk(R>HbN z1*RykgXE}Mbkq0s&*Bju@b(_DE;EWV;-oW#j-WBLIM>^^9o;rq6dt~8c9i$t7cdjF z!{_H62#&(bx;r+7E~Hr~qvJSdMJCZ#u?pThX35AQu_kTU!&_&seGl#qsw zZ@&~6m`g5V2Y{w%V!~(iSt_{&h<68wFTI3nb}?O_^xNnqKjQ2y%`^gJfN|lwNFkcl zE}Hpf9EG{?-|pSJmm|TPzJw?-(k_#+t2@4)A`E3b=nm8rWf#_3WeT04k`X=K=3d0;vtr??bHHR?b?DIa!F}%GRsa;Je zwJ}u6eGs{%&SG9RVe#T~Ha!Q^j}!XrxGS8@sq4%rdt}-urPpYS*d)N z*w~m^&2UT<(wKZhZa&eApenlQvWJ5gvo(V>VTF~nlm znEt(e#q#(rtAjs$XwXDc^N!V$j#mEX9bz)Gdw$JadUEkH^y|C4@>)_=x^rePkEI=WgvM@^z?yZ&&h>vPx+td8xf@ozW6E;vNgS+T#Y)@M+ZfJr zv-eQ%iLOi>_YWRCxSqz6Pep^`r~&0fSTQlzfDe%xd;G~w0~kF4K<{BWbu93 zYeO?H7G-c_ayYb-Y#ZsDY@UMI>R6YWZzCv~&QYzX&%LQC$QUtAJj{!W$-ymPSaeyQ z3Vv!@)r}K$;qOsu{5FDkJa6ds^&g2^UR1wUij~Y7#D$TRFmcVnnQBP9n7+|%#n9}8 zZmn3yFS*?6-2L$I$JyD73d)~4P4><%w4lz9YQB$_3Xq`E3Ci@V z@_Qt6eXPM`dk{skJUcM-v@33t+pYDRZvJqQWy#Eir|BP~k#1ms(L(%=P~` z$FFQHlP!f;-gfFH=f-=bOnrYxpORaq;hLBBvn*GWdNg)S^Ut$q3)2xeS_IQ7nr2gL zGabTP-}j2Ai;u+x+T_M!eeqz2s~{+`LpA1I-%F339>So-4&)E z&W_`VeNs>JETTyJ5CvkvO~J5Rxc`ZXi5YC4Kb85S6nFwjgQk_jJhkSZH+PGLPvlSH zUXlH1tPpgO46u&lBYg!+So58bESda7X$OYmRuCsOcjnRI#WFa=uy!Z)RSa=n{sbkcd7#?3_qW*h zvP%S23?wPG-OV)PWCn^E2GTNc)A8*rNnPx$A z24j9Sx!|B7Q&Os<#L~Hbk_HqELwn{#mKAVJtg5a|e-~iP45Jk z)Lf!WD*&n5vLqcrq1mTn)}*XFyyX3u7R9dyo+(|J`N;W9gZP7}My{xY+A>$g#C|XJ zh{k|dUE8;JFL_k+7xkj@qFac!A)83f6k3_p$Jf*}0x%LyW^e9g%T*`TURs~UotDrq zO{-*v*IO3T^u|+2#6{vbQkr2Ksd2DTQVV`X9M!HA+qJ(lY;%k0M#anbz0v^lzEgS7 zj(X$QHrs2ODY0vzpgW=`Abd?9mcWZj~Io}lR|=F|*E zI%tJ*773vjX6Ze9drSCi68yO7o`H={%(#T4Vztnbw>?C>DRKvSS=7$z<=1Cc=(+h+ zxlp4TeQkC!^d>actu4F!-L$^;XKPybOY1*Sw^{)xt?D|DtF;@MjmE=;fkvGLdi;2s zYi3O;+PoqGOIwnc4DJcxtrUvJn&VjR$`jH^Hn)s)ac16av~*M|8`a|o7IAwFokFTP zC>r^P_)M}ciSBO?6ST_voOiKpCMjr?k4b7#cBLiS<@F?;HUC)F5akJd%@v!pe4$&G z%cGWMIh66&Ow#0O!7YtuN-Oe%dk-E=>~~@bOS)xJh<0ay#gI{u=xO$IDYr&l#1pC` zuXr-AxWR?(S!KJA9&8{nQyfm!O#@rX2BRiF@vRNH+5fa&!LEo3yHKk*h1!a<@1cf6 zFO)s8r4P{0nu`HAz;&CRss$mqD0Y%zBE$V&O&hMuesqYmcOU~q;|Lm~sTXHS#zveE zJ^qey#Q+^~IH3~mur-%w-<7v%u{rXK@WZ%XlSm$}>_TH$u84k6(6up)RA%F%=?5eR1Kk_Yb>3OYd1(v0G4+Fi@7|^MyyVi?GsZ2U*SB&N6Le7*OFa1H zkW#jM%19@JNMP{F%F4Zun>g9h8h_F{xUS~=gA#%7Wyu+@rQy-6d!m0KZdua~k>YmwtK&ZyP(3xYc@%-F)_g% zvf-nw)#gx;LpdT~xYk^8tAHi6ii7O)pK2dbdX>~-`}Xa^2C1*uUfM~|&9)Uv z8v5Q@D~__>Gx*glFWw8Q72Sww`}z)4kfP483hi!TWr6o1L@2)9R;IiRd?;4G^hYN1 zJC4O+8Vq$eYGsX}&TOB_wk9|J+q>h3;$Be(L!NeO=Esquijx^OD9xoG5!VwbS3slA z586LmM^I;!V?v5xvLc2=WJg(%)(Owa=$d0k7FfL5-DmqZ`pD)rNy`@{7(H4x7gHDC zM%@aEX5F?pU3(jr*?Ed2BnkRfR_EXUeo6bDN5d@|1y4`UCpA{A^M*EK2eSr)0g$rQ zEz=?~{(>Mqo{o)h=a)yZ4uV6gzL>|7yF3-TxZu?2^pzD4-}iV-2OUdXGP|R;%=k5T z%JWNyt1ISbk7T%qdP__&+_N(e4pskb=)L6~H>{M&Iq_%Gl4jz;ul_6$97NK^zp>qG zK)avDLZr*f3-CBxyCOyWP}nSS)$9>A(`|6Edd!dbZ!3O|h%)ESn$8&9_^Yg^R??*u zj63%6eqCVFl`ZRzC)j}e?_pptE9xcnC6~`^j!VD(E2#8{{)D%V(?RgB(vg$(9-=yJAs~EPHNzZG4}o}ur3K;Wn$J98{pLm%6lbzFu@AA5 zBDH!|VeLjcyP2JLCCzcA$fe0>z&elIsr#9xXqsmgROkeV(w;1aCcBNI9Wh0;GGv_K z*#XDKn4EPgHB=*)J-#O}Scb=DyQf^eE!hL*|4{WT-~Zdf?Ns5L2&k?+u)tnGd%(N^iM# z%Mb(;wyZsyM5w0Y8r7MqNsOg9o@YVU^JKUZe;a=Z^3PwA)WDSP+53l#)H6B zs+XBv%^8`pwk+{GDMFfs+ue2G!Srt0kist5Ba4?HQbg&d@fj_JLI8EdpNP=>P*&}^JZ)b(d7E|*{cEam~lsQ^doD_T8J$I(3z*3zIk&M zbp?JHz@^Z$<{H!nZ=q7N#Owo zd?Lt1U1hOq((}r>Nj5{+XAqN;GGp7um#LLKaHNt31$@KMw9iJwwjgWKQt<8%rtiaM z^iose}D7Vq4yLE4#viVYd1_!Fud)tlaTSSibJyS>}QI3F9K zTzWP$C%>VAz4LCtj&4gA7Q6f%zcCk4_m&x#(f}dcpUm&|r(zt_OprZU!NoZOOW(!4 zQV)Lc4!+Z1-GD>j+}fjy5PnQu+SjnKHk{g9>MrK?CiDAUehub_Y^5z1$Bk!B=-I=w zU0yZz3a1lo_f8TjRGBV}&)O~@H?w`Yr&2hU2)B$(*Nqwaxw_`Emh+dnR3>Jdum8=$ zLuP~P0b|@pE!g-63wUWoC1Zv9Pf>2fC@O%)Mt$b_U5TGgQ$QW){#QZD_fzTJZr{Bt z&8XSNZ7)5Cr(9eVz_(D}_8jy~bdo~505Y{3uH>9}mk(GRrn2|WTIII;8u2sYWh@Kq z7J9fXZv|h#fI|99&rZ##Ik(zr8|QUfb$K&JZ3YA@{DI=uaxW)R#fp(e z=&_*jZViQQ$gCu6Rr2%CfL%uuZ20X5CcIPa`EU%T7D5hWl}keU*5E{;7qUPHm`1%( zBTDsB65#TYmLq;FvaFXES81F9lbxboM6CEAs01GRd&G#@t3ErzE#Vm{tEgNs$v*a} zjlOd1wgsRPDmn|t#3vs3dS{bDQFm73eU&w17l+8M6G5XP|Af3)XhLV`p407UmOeDU zL$c71vwg92)50wE;3SoR_eRIgcqW?;-{R!7PydIQX>!@xM7Ag^@xA(~Wz}lX!d&@Z ztGxA(dT;Msk@8Pj@ny|TLoZp@>AmrMdH=hujg80Z5>wxC!w$JzH}T)mX^n@7LWm^? z8CR)q^946iUl?&O+Nkq=ugC9kBJ@ftxOL}OcO;rrodVZWRJ^JRA2rmncqxSf%)Ibd zW+wnxCsOFDd@tfQfwL)*7%coPXZ1ODj>5`t=;nNbt!ac@V1`{>XHCo3NlDY!w6ljq z&3+P@LQ=#49~-ZUcV?l4FT-^@RO2A)&u9G1#*RS8CTdb?f6ykyr4*sXl))$%uy0=R zp4FyaHrr<0_Za)-X%Zk!2$QFYAq7jzA4)sCI%D9OV@T>O9Fu&?%o}xs4|kZCo68-* z25R=(z*U;V@Kn6Ga&-RdU6;F949FHel-^M_ifn>gZkO(-610Akj0_vQfn~osa^%2) z19OU2V!#G>RA*qW&xuw?=3ZP*xPJR^>)lRn-h6@{7?Hb@s~1zQ8Gd>;rj_&!GcTL; zmaS1MOAB>Q1zy>JKd3B`V@RkTZ08x?TghbXucNh(_41YrphS%;=6~vQZuFAspRtlZ zrI_KwOXf}>7DT^uH#hAbGlR5i+&kxTk@K>O8rkmp)MkihU&A3dn9l0t?C~aSG6=p| z;_@#m?pJhHtYRt!pknfFXNXP0@llPtnOnX_ruFfb3o}Q0uhldN7t6-HAG3?0mj1Vx zfUC#OjhtEf(hINJzw+4$2B8Yd>9$F_o;Q51zh3{e zYpHx_fzd)_dxDUY=|kh1Y_D=Ox19LFal?6^r(bHk!iO4?ZCHJbI68PnLpo>s*E=P{uTXMj|)pH^cWJsUijl`RmbhD_6}7 zC+mm4{KjK4%rN(y`+k69g2)4y1HjrgnR_QmF{$kK%ncK$WNxkbyp4k-Y$s}; zFT;|i^7*6@l{AjVn-ay)G~$ZY+O@$)4GLNv;PuS>oZ7#C(}oIfe)%NZ=bh5`16N)# zKP6H9+VV)Brn(5J8sL{P9!R16qYzMWY1curpE&MB^K25a*)DWJLU4`Jii!o*!(KV0 z=O_4fi8m$bpwbZwKfBT`}yJZen&$u-)sqHS) zT*+E^iFcm7Uqzo1UqBfAiv}hA)%mo`o}p=u1!ctxR~tUKT66WwF!TJS=K?GI2A3!n z6<@w`KLk7qLYPn^;>t?NHG+Xl@J)?6ZQ9r8Rz_BX3!AzCU8c@?vg~=@mqb$>=>*o} z`*blVOBij8pYQDAqNJgB!t*cmnY{lTWxtG!`58KTL!T-hxf&Sh5$P4aWR#lPKH5Wj?yRMMMsJ9k>9Wqe zLf_o!Y|n3V2~^AnXP*9RXu!dObAU`TxdWTtdn-`NV+FIi9$!xTr51Fqy1x&np7#HC z??-$fu>eF|Y);>Oo z7Fc$th)V?xBvS9);HWf$GGet$yh`7vY0l1ppQ~np;)o+HWQUPm)~9;-GEI_o^%B{d z4sTdA@dQP|LH?%~s|KzqO#~Yv%r@Y%SFUP!_{$cD68mwobXr;mrY(XR zg6OsnR6p_h@2f+t7ha_v5}`XPL}E`1+4D)Y7b<4cb8d_;PTV{2?_Ud2zE@0i(E(1m z6k7WC&!>BQy18BYR-xN}<!rdv-T> zilj5aAg&uQO0=4O*msn3RB%&W!dDLQ)4hOgMd@XY+ zBEx|I=OH0yFCwvODQqPgXIY*^lKyx9E1h|_sy{FLnS^5M!86x7)MEq|4j}4M^8rR0 z7F7(pg}mhKmdr0hoFDjXt4;+77mHyA$B!&pX)tbmObsTp66*YhRKrh%p9$=v>~xPPt785F)$sBk%i9ceSv|6zZ3NisDFY6q zYPs61T)7WGEeGBtdx;LcT47$b^;xaOJcl!*#q9E9WiqPH0u_WqUt$vcuY)>hr?X5sL4SZwDZVUq%eT5T7vM zN9H-gpZWU@w|u;}EAi=r?VqWRw~IB0sVHl~*2=O6}(Tnz*U13h+|ULW+n&uIJe4}pvyqa;}JB|6#ue#}CWLfo$pj;E}~ z-hHY#v$Mgur^OqX5E+&%(i?5`WOR2>4OfqmNH1GcM`SZW%5aDhG%=+{* z*|Z023&dY$OSZf{NENmOkxAP2p;_L^G)Qe2I}BNxnE!Tb!qkPnE4K`MliK9m_lnWh zPQj}?2OY{(&pg%wp>4=)-TLM7>C)Ndy&kVK7z1w-`W zVXVZweA$MiD707@`s-m8d1#s&`fs@UsA}4W8m}2=wupTOc=4<&Dk58BFn>LD=- zgxFub`PKMSfH52~31(8;2AxtG?{C-8#4;g0u>T>WeiIE1_v;qSG=Ayf`}y+<^1$Vw zSjGBbO)szNWFpgB1^ci~TU1@c0$%FO^O`%N2VZ6{SG=KkrD7L>1m zw&Q5_e3YLa>J*CPY^nr7765={^K9 zk6RCcmAhb4aQ;{4gsnx_s*N0~{sM{U1A+|VF`1y_aVMJ>C&d;oPPWOG-p7B&?DWD( zZyYx~@Og28Mp^`DG9{qP)offO2cLiWBUza8METPEKW$~1qCLci1^gAY zjV`YOFMZ^Sg)WI{f_p{>VM>u~fSeCfI<`Px{_>0es0W}ZjmA3~C|UZ=9P<5)la1k& zvg{e=*?FiD+%;CYmfP1c{GT`D;~Z9^GaYQg?T#iX4W@NG>vD?>vzJ&DYOh+dvYHvJ zd(oRATN#V`j&a%{ZR9`g1Cu`2eZ3JH>L!QxCaY!|Q;)AQ@Lg62-5!+S@|y9l-`#md-Ec>w{2%Y&5N1&^)H?ilct?kYV}{zO>1-WQ z`~}S4u~okX*bq`0yA@$`V7>8Sd-##J zCar9)?_@dd=iWkS-=2|sPDN@lIh+EyL#CDvtud0j^=s?tQ^UV?Kb`0+

NT80NG-bvM}h(3;4LnS!HWv#>V*3m8|O0^wuVvc0BAE%kwO3} z+~^Fo9O{109ZO!zU*F%=U9|IqQ5o?%hdrE}oKS}hW&cCZLE(jaq;ZntiNMRrSs{Ot)-DvVITxc=7rfCp$1^r zMvg&|5>IIsfPISbhSm13`0lZ$ijuGhzk8)8r7reXMJs@^3tDRMp_}u}-UT5YJVZ|q z1-Z(XM%Q-XrzPGsR9F|P;?ON8z~CJ#Qd&xH*b4iZpht9vc5G|`4o`^^RteHNX!q}f zcacl~0yNs2ut-G1H1tXoP!f=@hX}yPh;oFCSfQT`5E70WGFi%omzV+gpwQ^Itb8P@ zQCt_ueRWr5XXE@>`0RdkC0#k`!v5Y|{<}ZS$(lOm>M0@;b8f@2viv~M#s8wz!^Yt1 z=}mTi16u@7Ey9gOo2S|5swOCvJQuK8J(*W%2X8!R`H z1$O7`R9WKqd&@D^Ddy^%6GMB>GWxlV(z}$;<5}fWB;IHd1AuY*;14|V?HrVHF~^3M zuv^1(?EOD9o!@~u!mY|LO*5w{*b|fqySiq*Q?eZ`>!-it+ieH1wk0@?e|si-R>FJa zzmL#9eA{0zSdbbid0pi{<1`EDcYl8%8+2E1urNx%z~D=rUJgh_JnNNofRHzN{+jd4 zq7+DMOr~1Sbf53f*e@CBuQ|@wSj59s!SAXK@aEgmznAOkaPCC|{n1=}jG}HyHRi|* z9Ik35Zg}rbn#F7545UYpz;A2|pj5hSgFE$qjEh-=(Tojk$uL@z^(w>*+_&9FlwGJc zsE=+F-iSXRHOc#xD%h16OUe^#l$yIO@q~5LbXU{9SHe}$H{9tewh}bbkxV1Qp<$G$ zuEzTJE8%crY-oPGzmTpBDY}-W0Ir~8o5Akr|1buD9FYj~{@y2SW+eBwdRh^((b77d6864oQp z+H24-B*OE!e?6HGT}vjVJd-PMvUn zrOXWl_KE%sU+P?+scj%7=H%I%&h>9k-V~6b5s1i7{Kq_T37oC5!uaJW8oP3q6Ea!u+s@j~~exVe))Ivy2sbW9nNvoW3fHB2*B zx+1T6W-R2-OB;hL@ON93I@c2bEzRY6C6de4N_H96{k*+QOG{%5_qiD0w=gZ+5Q851 zRNbaXn)C4s#>f%1Uicx5Ki|rkk%PI6Z)38m@ix$-8EHf$Pcfe*O&)vYv zU{uQjkL&axv#yGI?n}!%x*sQ>L@Hhf&(_AuwWyx;e1orx9UFO_C$d>%M{-eb{k3BKT;DSC^O1aG^ z_c&joGb#eRIrQoJq<`L#*I~DZ=Wwd+s+H1JMA#=Zt3+^PNid(nNE9%a zQNhR(U)PWH9i%Ab>)M)q6??ZQe|V?mR?3E$3h!p_$^ZUKL!)HUWS7aIgIx%8XZ4Yr zi>+JYY7S$pyrgtj9h1{kjOG1OzO<3js%FLn#^TcS)OC!x{3$q3rZgDBT;e7dTi>~@qZY*z ztu_omuYBzv2`XOaANBG3>K&&)D^7>{eOS|=RucLJB6lGM)=CmAsmX`!Y>`X`U-f=A zOn>)r3OF^qx0)g~_mQJQ{%Lbg@8V=m@eAI27}|<8KJH?x$0F&k?wcmvmXrFiHiOC6 zJ1{9ohrMz+WZx=g*xH@_iY4W*;22f1-f$vgq0jb581pOG3k;+#0@caeP6GkNV)}HtG9}Ad(VchUK>{5PN#9cQwMh6hGUJ+_wv3&^C_9fd(droW9 zs#z=0{Q)lO3k`IPj=v`u;KLH?1O=;#H>lcR^XaP~vV{w#6s;>WNb!v)sUB0f7ONjSFcm3!*Xg#P9g4DRZ32b;DLNYB1<|EX5PoghV&?2wYPySDpucK zlBw0f+;-LZhKizWig(IV+wHA|s#c9-6yB^8`o}P#H@|xqs^z2`P*T}k;(*5!J*5~i zHMZp`UjeP0+FOQpw;NX!nwWxhg5jqy}k!u|o8?T|&ivc$&fRb;01=Z^#SbHMhC!kF5QlTZ5=l0aaf&=kH z7CGaXV(@fDK2abEN8qN&-s&COR)o%iW0^KcSV68NQVBUtNSp$lFY?|^YGGlEi35dp zG7JE)!_rM0znkx|Pm+s}dheYzmngl_J_&$H_;#ejVi@%YTP6EAH`j2PD zLy7kHi_5Nwi^P4}_o{t1ZRf}$4-ne2t^!VQ>_Ot;LrNkl0oPWtxx+(uSUiQTa8^c42NMi`7UyW-!Rs6rSTjSg+TD2zv>XW)s_eD2fF}!y%B^ zwsx8u1;Hqos{FQQ9hNsYV3U^TyflnW7yj(|MRP1E9ko*`>iGq=MlW6*@B4kOq^xGA zbXG)t`U$S#?;?e*>cHD^L|1f`ncGAvG6y!zybSWMet6Z&zx_E_J7oGBV}^Af2D4dY z9a5bvkq;bdW~vNUKCsaGxzko~)4{nbm$al53&Y*EIH94PVV9YV>0wmGguXwygDSlgmTT(nM`;e@;AO;5mN&%SmKdMj3B?hV7wYk269e*mBj zsB9PtucXx%><~Zl25<%p<-VtuXXnA|G6Fz9>{yN8wS_;EJr<5=OU`#4wuq8iD+#(v z)RRJ#=;{Jf)>*PLZ-3OwSPJ{Bc=HxRbWMt#QGoreZGl7szNc(=EVd-d*%!7RNq9zQ zOujuqp47&2oiiYV&_9|o^@=%m{eD8%gm(SScP#>glbR{`7%{WSkqkA7NPAmMDIf4% z>$Q4(SQQj7TW2;a^P?TbRb7mtA?k%)p&}oKhrhd- z+c^XsyUoT3VZDGO$shN#n39&*>r%i;1lV!}w%tA;;3>`podYQJ_=qQ|;-J7Fw=xu( zgA)G$WPE~0e2Si#xqf&T77s%1o9&}j4^udsa{yijF*!DPA|tP63<4|wE<{s2QcHof zU!%%Xc)T?s0B{SWE-e+Ahs8N^y)FoQhgk$P#sIX6XzzsnPk>YCgRW5kaK(2a<~9?p zXOVUqRv>lAl>oVqB0Iw?=20zssLp|_YcgoonbJ&xu0U2!oF(jOTy#mvD7=`2X*ypS z3s7$~bG5!(;hP?Lkei=Let?M^u^M|kXQ!?H`nH^eS}SG0?%MQAoeTQ+$ZZU_1ftUb z$=^JGZIHR%aB=24VSmd@EI}ic&JbGr@nvSZz#V>S>>%B>863>2Y?=&xfOhmbxvmOv z<~}>tlzN7qy1Wvuh^k0uH+1U>XlaKR(%@B`wC9Lt=uZ3SPF24(FK!l*%v!mjl9l}N zi)Z`|Um#{tr)d#KGVMes9Gcj?lgcw9hAw&3tyu}FOJTn@2zoOd%r5BSQT#Q;rC8aE z!ny)kIRJQ43%jsqkeh%UP1HaEB}xZnnP+PO8A->P5ZpqpIcQ}D-pD9}2yVmk*Z{N# z=yzYa4N}=;xGCsmL6Nv{XHbTai)%m1>veW6shtDV^E0I6J_De07({~z$V1kf@xpN% zl_%kNs6`PBfU(_&^BCEKq0QSR?UMF!Pq5wLM!w!BPZ}z_{)oB%dt%G?iqc#CoYYL) zMOn2CQ$;yyGOXf!5bjV(1%s~@<&-NlBJJabiFZeSrz^am?WCYGVbfwmXYg24fGDZ4 zeF?`{)q%Lj#WLJH>!-(nFIjz?gA(U!?7_bI4~a0ehca7wf1k-i&R$P{5C&YWphad2STI}_w; z(-`QgUJzo?abJr{@OI|1DIvi+$|uNwwK;S?snUsaV8?X6m*b)n=eVfz;@l(boQCtf zC-1XGQ9jUn4D%dZkyFUIAGDl-uq{l2#XZ{IXI zV?z;)fUzxc;N%T*WzK#lFpff|`*07R^@4Iy3H+Uqz8!gB1FVji?g&_gZ8U%wZEG&@ zF)*MD;Om9gsS5d0!`|_7gYp!lF9V8V%N3%oq@JnLkR5FV`#LVzlYrFwb6XpyLWqG( zPdsNf9VnXy>*!w(OAQ+^4#tEJF+018YgK$K(K(L0$!mSAmd;CTWe_X>ogP0GCrO1` z!05p(-`1}K=XzEW=x2VVNe53f-r~76qnQ8-8bgO^dXEFo3@CB8WqjhQs?JAAZ{4f6 z!o_S63xDYqVSg#%vVJK^?fl57=%QukFRG>#cSW5_hJHR=8l$yWYM&Gn1#y0sXavW% z6?z$b>FgC0Re<4t6p!Yeyu;o{jcpyXhy=4=y_-UYAIx6P$-S#7>3(C|g+CJT8y))#)&k=>kR>#EJ=ae>jaD{)oOB+-QKdq0gfF7zN87MxH|O*XC%% z=~-}VN1ICYzjIcsl5dZm%lfJeYN+^pXZR<^23v#zBE8*Y-WsphL2TtqBrt6!+oQqO zD<#Po4&-e|8M6wlqpDz}YMP^UT<8|3_?lIRHc3({ouINIR^4~VWyzFyiou90YHl8- zJ2o~itf*H>;-N*2sV-1u;|lmh4eQCV^{UW3(-Va<8$)U$r~5s&#eRH{d$;qEC{1a- zY)N%L3dAOuq;eZFk()-=4O^>P>?gf;3F*t8BhH$Jph8kv4!JuTX z^VNr>Rkvk=qwMXTFLyRT6q6lHtjyB)nE1y?W|Fz<55!vlZ)vpfgb)VLcVMkM2>Gwo zfYuN4qj+f~b%y0Ir5SrE9#cVnKIFoR2cG5l77{peF{z>9*V@dk z&;{;Tf$`s|T!Zt+9`HKO_{-jqN>^cEn^=e_G``Jsnz5wxQN0bP3Oo`)JJz!1U2vpha!zHgKv^_m zqAcmkHnsDKgP)^mc!66X;n8%X>ZA&*bgHrm`-d+8Fvs7#udN&@jN;#B-z25U-QlzT z3WOVtv!r_#_z)q(F_`)M=DmCLt{d|~aGs~Yi3QfXFIu*sML_cmz=ex&Q$GPi6?ut+ z3AHakoFMDbec!7^C+v#*LF;IanmON_jt95)t>wTSc`4VoU$`sNNH50)tbijpt+#7$ z=oXwIW-DeZ%pF@5duyJeewr*&h&(0W61bp1X`hKXLL$Ui8rZ;XuGm%Qan8}+Twk- zS52Y6_Lll!c4|Y_>Y!>q*5r6Lpr*uB_;p0qf3X+g-aW z{H_VdoJ14veOu7-m^sBAd{*LDVgED%S zF-adE0{Nek$yk8sQNu!tB-3Nwa%$Mrh`S77uU^&3D{0T*Ro1D(@=!a4>1xu3fh~vmgLi?oAblUSc)==_t~z7Vp9??caqW*@tR9T=I! zfV^5q=XH{23%Ld(s|;ucpf&l>d-d_-$4YN*QSHLP(L6oOwjCT&G>$XEYkfAqiO*U&F!5dFZG_l zj-lK>pT>cOg~_Oegyecc8J&XIkB0E}fr_Wu;TK=foC&D2t~9;EOn`aRDrC@Gw$Qo% zA+~DI*kJ$$r|jS{qdtllTz7jk-_u62UJpBw%H3bg3o9_`o^mP}Sn9go{XP!!OR7bG zdj-UiW7y$+WG5!jM8%ms{dQ6nQKovmT1o*bd zlKVI%r3rVX!;BihYbRk-d`y#+tQ5Zqt1h0?8`!Jccd2+ddSh)zpY@dFR0nN0YXTli*b;HIWnZvJef$jI|J^I7ek zH(1eeoUmxl0zJeKv}0ljp)&n;pgS;T+7z2_+?=Mq%*{p1S!u`6gGyEgaFJ&Ce&_1XL*pn7xK`@&-8)Csy;pU-eexbH9iVqI7gqD`lU$K64|ibfygdHzm%r@ zt|~H2ADfD|B*ziV(tx;~@YhfGM)re-i++9eVKv?_V@*r0dm|b4+iivF+oBDDphpM~ z?9|9_v7Jv>GNfMni}o}=M!7Z-#4R)B;2~eY{iaO|qGG|o?NWa(0LzKNk=-t+- zk4i}pasgZENgw3GR;l!g5cIB)1Hb?>0mJ=*Sii%_$^uKslPYu;k_$04J%C_5@Zakg zHbAuA48>$@G?fL6NUGRT9RZKXDMg*CV-?DT)sBt{QZ_EmeoxQOY&#gn3qW;^l39;F8BlThcTLv0nNgXnkU1E2e7cRS{Su>1%N$LtZJeaG zmqr#By4bn{{BIQqw$_x%7`SnCrBU!nbezACD*Yq*&fCzB21hW}Ioij|UlnM-)M1GH zoF#L@5Q1(1H*1G@la1KEmW~ue~U8;`|QC=E@{l9~h;4 zHPL@%aRz{DL)Ck6Mioh34K6jumd8$AsJuA00eB~xut`+S0jnSgtT#Q{h-K|>wGKu$8=Ze6ACjxDni%@ zCsafT$;iupMFuLc+B~UyN@S8useuWmjm7%?=5&R&%X4%OKp)H{p{a`6%r>Ex$qw@m z?QE7NFM9OhDw2_5g~VSm+OsEmx#=8kJC{;d|Gwjl@{bS1hB}tfn~tYOxRdnp6K_6t zJHcdZ-8p43&v~VB(Rp!9ok^HjF0POJT zma)I}UQ~+QH@Jo^kB?xQVdAe3+!08;ACeQi<4<7+ zTiFQ;v8uNwy6wAHa)>e3DGpQ~)iGnmG>xeo{j%|zI!Go;kb(pNWL~1fIg)|(J@DRP z(0)CGWuyQWmWb8~)^-UMCmHO>C8AOQ1x;OzvHcOx*v}r=lxe(?*lA=m)G9@HJv_qD zUXPS5P28TC$aTX*sXSq+*Lv6zBGhX8O1;IgFpCzya1(F8#N1T>yc0}oxsg8tZ3p1O z=l8pYN`8H~n~398)ad<-nulhkQH zv{tskGLAGm@T8~$h=K-($L#7iF5~7EhR{uHCqAK3xRm}}2%qMYNbn6GC0aXb@Ntk) z)nie(afqS54wGG6!dF%Xtv^Vv`1u3&9tOJbyJ;~jcaR3W@SME-@OeZ!6(mI6J2we~ zC*J;@bny=c7rFT?AMs6z^fex-f#$9LqhlV`$q0vel3H~=iB~Ut0@=N)t!Fg|5)&2Y z`v{HX1x;maE@MwLBah*!8GbP?Yg!Eb%pw`r$R}D7&F5b~U5CaPZOcsE)?Wj0N$u5L zD9wK|AdEdFRMbMmqchC5iZjhuvr_c(aS5A`Mshjr&j!4l%R z(UYc1rIMt=GmKILNTBcTanllK`f z%d{q!-*fkg%kduR3>W|~8;>}cQodo6GWIzm%y1y?oVPlhho8+F0akixMM6^EsNUxL zn>7E5;M6N-KWD>Jx~FWs#JeeYYpiff_SilB+(qx1 zJvNiVqc0cBs$Ld4u-TSPw}09#zqF+mbE&ouJ&JBX#)s1Q~XE3wI zDDLlY4h)o@{vCc1O*a4|}>%;gyC_}7wbv&ISg#iRqwWLS7XH|$_iTjByn=980cauap<6dZ;xJQQqe zqa`}A*_1<)k+bv}mrafuv|~r~JXrcSNJyVbMY1jiKNp91ic7|QTMqNs*O_oIsk^!* z%1OS0M@WQ~uzohTKV(c}`DsZYKtxLmy=1}Fwk%IaIw?z8qVI%+#{j+#8sBzUDoT&s z4+2(~?2u*-5KA1U>|PuRPSEr#xqvQ)Gc|71FCqbIyu84jJK^SvV-YHApysf`!cvxyK?ros-q3#(MFxcR#dNCaayF*|aUok@HwcoU=XSQw7k!#k{$B`DIjk z_NuQN(zMQBWX(|}9*W~n{<>~SaJLkDzryVbUkMUts5cdLb+K{%&rkf5m|i$esr8@0D@q;;+^A#?FpRD)V+*KPxsOi`ubyJ|bwO z>Hy)11EqHj%`ZostK3<_g$xO%D0)Z9HXRe%GUm5|`SM0De#ljoMYA0#w!7Z56Yo?( zQWti7mz}w#=4rc-GX@j-i#gqlUKcmLBgQc9&OSmv&;Q4DKvLzkOu0vM5!FU>-j`S( ze3lSWc1ff zm-t4hAU6|rFeyPGXhddze&M%Vc|yLay-)YOEE2AJJ2atP!A~w+3{6x5Xkq+Rvm>Wo z|8=%J&b8lV`i_N##vDGG+eg3jHu4<&a?p8~DuAoN%b$o(V`&6jjPJ~$X{?{~Gg~i7 zXe-OovmN!g^85FYX2r*rQon%o?@kNiad!wmQA9)xT3;Xg5T7Yfmoe?PJgZrj{k9Md zBjRs=Fvx}PJ-y8Vbe=;hXT}AdCYuOpi}f4W@o}ord8T(NzY2p9Dg6dc=jgk$)-)WAMC$6n0F{_%!=8y4n!vGw_xqPaRdyi&Y^HBC*%;Q;%Q zK!21n!1LN8wUN}K=&ZnYynvO?5A+Lv{F;lJw zOUl$);6h>|-$Y!5D!Uk+lpEeRIoZYHq=<Icr{Tdte)9RP z56*Yz|HAl^l8hj(lrmkjfzoF&Khm6i)$-e(dl2#bd}vWUhx%K$#p5p-IO-xjH!qW? zdzry6QvC{cjLV!ecBk0rEyfzPo4lz{Tp;JVJav`=?{uUa3AJctyD-mDy%G@@6B^p1 z)bY8T+GiuzQlax(ZQ&o@Tc-2Cy?%VTylQX=3%08HbZm@xraK!GGn;73aPjsjwjV7|{VAwtUhN?-br>Z?qyPR9mMR+gfSQ z>f)Oa%@Qyq63K#R2TY@4h9)(ykceCPu%>t$C>aY}Ex0Ns!|an?X@3MpS4@xQP$?G& zX>PNLrm7qMH#zae;>&?UzebOIQRO|A;zIUeKcYpkq#v6*lkr*2NXSzj`U zZYJk>1K`@Pr=(bzN5%X?nJRuxDkmE&1=WNO=-9Vu8xDFWzo?$l^@O$RVB{S&)m&x& z3c8ioCFmNy|8znMSJQ!%?)s~j_?XZi?{!X5)=wR%7hS@nkSE?qV$Q@lxw^(mxaS#E za5C+~@)?_Ey`;uOpH)5}Vw|_AX{;^B#9|O0lejbt0bI0;9W(woC5G!@+Z@m&% z_p-0jW^JyW7S;2!oRgK+Jh>lySDG&G8{VdYU1DQ@L(9?YBLgkn_PmQ`eFSJB+Ks@p~c#?i9c!^B54xGAOOkZC!M(C0cT< z|9(IZi{_=uLai%(3880qD=AMrt^*AcID)*dI&niL_JmDD{g}(Powiyo$9uk=^_*_0 z>(pc~(kHZk4#a-0pWA3T1(md@&JtIF55E%2zmjRv+z)yo}8E@{^+7Z#}6FU{!|oZ|qQ=Wxm`lf>ehY1{ka` zLHOMt9Uq%$XuR0Hs~l71QLE(ddK2XKpF!J)&}}de1V0)ez4rD^fN%xzJV0avY6db$ zTt;|mLPwmf0q_cWV8uEJlv&UMqKJiymhJ>w02iVBHhM_%gam99$m+Q@Mwzv;@Vo3O zkj&45Dg==B4Arbt;EhY@AkuWc&b(ZA76dJovK2j@K*fbr zz|8@ON*bMIjMMat5tq-W zzU*Ob%cl=Q?$mcnWC6?to z{n=MPJkR7-bfj-fOBJqFM<`fGa09q+N5YLu^ z*7(RJ8-kW00f^Ia_NR|eoAw(2|26zLs4*Q=jp#@J|I8scR??L2$FAELA}gsE!fC0g z0ubZ4YhOu|_AL!G70y#>(TWqX%FjUn`(d+R7-V0t&*AN@1h5EYC<8Erglr&9;4*2Y zIgGLee?Jp0ARQq74uF6NK0^7UAha4cZ8;3u08nBC{vom$0c~&ytdlens%%GYHpdH| z0HKuct5@^cu z-8>8O%~v%ygt^`Fn@33`sh`s;kx+*@eTSBOzMSO!4Lz~kr_+n~lPvFsFZFRtmA~XK z-Zpv%=Thq4TC_~j<1bCbK9BK-I)-Zj-=TM@!H@75XE+5C1^t04Z%+^6c>m38@5u*T z0!tUH*hkxNACp)Jo7~sRZb^KC33VC9Ri>OdVYvlZidVz%!2^@}7l ze9gn9Gp0iWw{X;3ez~gw1<@PQ4nxP1%9x3f_J*8YokO%E6qWcsT`I#G=%zZ;6Y2fd%P$zrwoN} zF;bWEYkTb!Pxh;9{{j~h^IY|>{wEZ7K}9PzJaNRdizKDcWB(J8E|QhP@K?*uJmzLt{lNN7}H8u;aKH^wq7 zFduyHXznJmzXbIUHN6yH7yfb{Votmg5*);N10uZDA}@P9M|KjAbnxkGv6A^-N}m&7 zw-O|@l`y`>8oaCB&|*afdN41z4dxd#w9DX6en_tNe^D^`DwGf52N+I=VkRHV5DUmT z1^OhIT-P7*RKQ6D$R@e$g*HV< zjD!XY!@)nsEc!qSwGA#3!KJcnxJZI*=Z}zT|YQ(GTq8-;>X`h)ZRefJ*I_EuM z)5><`$imW;RKAg6eJ8~#aIaj1Qm$RY#sV{0R9fry%7Z<(eqAzETUU*x4`-ce(R#Z^2c64)z;dVw8dfkMtWC zD~zXMwvV_WpTcOiqr0THHEgYyVPp+6WjD|N(+YJIVjTxM&2z~9zD28rB76czm>FRn zPiKY1l8$2!H9)0B!t(?cMY~{^1d5)GU~9H%7!E>s+#ZYkJKv!MNBRL2cRyX@R5TVP4ieHyOhVSZ}YcI z_kVxcS-=o$*P|p>G`Ir_c8RymQbxocMosFq^D0?P#Y}&rLp%EdAnNbNho>gnsztpP z8VCY!H=C7e*SY1?mtWH(OaSk~#zT!;)3(m6l;Rkc_ENL|Ho~LxqvA}VA5AeJyCn%a zmzpaV>Q)`;DWwN&eKgK!Q%rK*74>UKb^x~q|NL}7bmZLwyS2A2E}edKa+K$E+sNeG z0loeofYd(R{_UP{phH(5nUvG)+0!mdHqpHEM^~5PeO#LSiF4d>J~if!_YW3ie0@em z-zaK2UTuA>L}$?2qoA|gH`{spc~AWbqHB%Vars>pQ6>Id*jk?RGf_2`X``J{3CnH$A6%9P^Q!ql zbfSTUmr(lLLTO8)m7$wAIlpADRTNeWh)i@;deX20h;)^+>~_S!Kp%%5STT%;^YmY>LdX{k*SslG%WuQ7DS!#YnW;hteO9M zh-vq9mwRy%D$Tk*r>SC-RS+s6@*hWUxX_hhLi#*~1IU=8!lb>xmJ?F}@GRVZMqN$q znNUud9|{pjxsX}}^P)J0S^x;SX?c#e+z!{F~2w?M8p!H zq4Eqf+i0d%EvmZUB&q$JB+XmueUX>JV=%$}P$3z_m~{4xW{-xLL1y5BpZZHo>)Ul+ z4pxuu{JPuL{`HYvypo;UwH%92zDoCQShME#)Kga%rT*`nPqWO*)8)P0M+1yCi(lbv z!?acXr#=WRe-X2{(&0RC$JFKf?VGtfA6zSP&I}w)x1P&;+?e*Dy{&2Io~8YYv+bW0 zzWsUA;~jj!X~K#6!_)soV;>EZADvx0(>^-0+$|h?rp%6Cs)a(@9=~vXou{36n{vpH;{g!H1Bgec)OF;DR(yqsi~sfvVEVHOvmm{^-b?7 z4Q}?DzUncXZnR|WQd@d6FgrfP*)1zP5V-hrQY^j132GY`9VwzPmx$q-)&g50Z>{I^Gn_ha;`e!;&2JbPJ2AT=Mc#vfdT86Mn$ahj^=ory zCB3tfu2r(%U0b2v)6HDW&%@h}9^o{t)T;HYs^qQTan(TK#j58}((8k!C#h@T{i6;0 z9#GwufXuibe-SsI?n>Wd8Lz>-vaq$ya?^9^kd<%R4l9_t&3CHfPvOOijrjxck&h zMMknEjObJd9cTT+-9Pk1oKU({wZJS!o+V2ZTS>`7yxuUpEaU#|GN1?b8N?4pz5t!= zSDs-jnx8onftEZ)K_r>DR)V##4uTG3Cg|hm&jI8Z4r%KP4--3CcLy^MY9q6z7)HUk z#V%1%?M@%~SW|C1E-j~I-M>|X>_<|%*Ky)ZQ&Uqgytz|1dl9q-V!#Y`@hBuKoMFh(-roMCr-w&aSQz?!3~+dk{K+rl08dA~y0!<#oS;zL_!c8#^7sOtQuw*1 z#zsfYDwz#IZ}p2Hy|j6(IE1!E>C4?07crCDt<50E zC9y(4TU)T-Y@<|2q;Y4PZ$7hihPb?B0^6yEhFv>Nr1IjHFK<+8dETqJFDUw=qB)+!z?dHcR(_|xo$++NIxJ&Vs-Y~(ijTB?0)ra@)XihHeTxHNTJ83cPG~lY~ z_3N;8xICSA+WNK$!xJf~)^py_|Hx=)Xr9<1dZofuXIDy})KKHrvZL3e?>(z-<)|5~ zfAv-MkBf-hksT_z9eE$xvlD8)GfR2eMkKE;TN^7S^+f!vU!A%JqYQoa=RW6+LD>)D zIX0}k7po&J62bJ@`NKZpo{F8SyO};mF*^&tLQOy?dwgwZQHgjI}pmIU8-0!al)|epw3)nnt15f(NULp#owZPw5NA1p@>ufq($Z4 z{8Zsrw9Xo7y7`BPCv|&8?f773GVi#$iOG6sDWz)n(FG=HcC;~Xon&k@Ix*p6(Umk8 znK&v0k%zRjG@CyWDXFP^n>Q!#eI0rA+O><20Pz3Q1O7VUd#OEbf^HRYB_xw}ragx^ z`$k5>@CQVx0$!dek&>`7h0NW45Ys+<_;3OLHK8LBBQ=bLkRp!Si&+g(4BbE~RVk

NY(=(U-q#**AgC5D7wYc|LCk>geKU% zIv+M#TJXEVK+oGeTu*z-z`!8<9FW(TzQTJM@~5*9C&xnl85C%=k0C=v%yXm)xxxpO+nn2$s18&6e$(=i8mEy*>ppsaBj*x9xZ~|C?{& ztwjNI_u2A+|8bEx9CNsFb-9(akqP)9N4HW@lqKEhBbQD`w?%X8FOXb zH_2~XHch`O;CY<*F7G^H$8Pwu^T@Ehbv|~Y_x{?@2=3s^?QP4vIZCSGj5=ZUW`!V1 zxY)04Ow{eyxBUm1IQMVa-na1K>5uNmG^Y#|{JZOhjIu?dg>NiW9*#ETZO(OQUE%a` zbzIwW4b9qe1#j(pE>>jM_!#Ad7^mj|pN?b6n zN%eN#Q_oi=G^q^B493{DKCAeL)p}j$<$$Hi%o{>4OFDdzcKx2_HfuRzqd&%PmsHzl zbL-Kr5hKUA7xy2uWOv!HvRp`E;@smlORa0ovG3&(0iCuETH3N6EMwa>pQO3j3&_v3 z+RNRRn6H0&l3&i8TBDV)f4C}G)oWq>CO(=Wkz=3w>-1p`9n;x2Z1h`e)YEnPd*9Q0 zIcKvC(rqSg&PsBp&Qf7odYf(s_|~p=|FTQ@yqwt%P}EJ_lH1O;%mI zo+SFdy0h~`ff?Tonq>-a+Qg5lbWLYk-`#IinfN(X&AQy|h<;Y4ePVHW4F5fa7caC| zpNB4-<@z9Oo%*ZDMD**dCfO@8@kS4=8vD8 z>^=P+L!8!6G{LLB-{UCyz;PEtC-MDizQ>|XfrN-zc=6pQ|HS~5*Cf;>+F7rFjqkK-I zi#C3Ga$nf%#x+0JwOsq@z6I(<9_t$!-!QT7udBbs^77@nR~?eJ+gz6l$bV^ZSt(R{ zFx=cIpTBXV;=$Kb8$w<_O1r8Qt@)(v)W?3h5gFI3j&oAiF&*})#@9kWW0THC`Df*3 z<$Pw(t#@Q)t>kL-L9FdgJ0BKV>VEC)*rw|_y-(hsd2eB36&5M$er;+ZW8z6;uI~1u z`dOhj!*4a4*&+d7F}vgJyU8cL%`g8ksyv+N(y%keLh0Jse#TGC8xM9?1wZa8;C>yS zaBcQiZlg4SiC(rgr6#4%F(Z~S#3sA-nLlZ8?{ilDqV25uqds0R91f5!JG&|*Fn1@ z_gHk2lCCbyN$X^r)r=O}TbMPoTfe!zX|uvUPud3LZ*28qTp~j05gW#i&oHKRw=PWO z+}$B`BtZAKn@!i7`m~L<-gk1cbyZHCiP*GR<-@|A*EIfGr(P>(Opb_k9k&tWIX}r> zed~7yV)NY>SD)&1qzD-}mYRC;inD7CHU!$6Q#;Da^djzN<;e}fP5Z>TE)lPiLf5{t zD_^{8&RtXLlHM41arpGemZo(7L+_siT#jVx9_eDf^!&PuWqO97la*C{qjT(>? zs*g{w_-M%Zk3&zArq=|6^%WkzkJZaxRK)UnABo$NFiRije0o+s`evNXh7(EQ*(0f0 zne&2~HMKm0sp@u{fh_3IU4v4=dl3=d=D zNJ*hqU)$u{4e!7&Eia>b4;?yWe(KbzAKbc7 zG`|UwFourrx@FFw)@Sn^r|toaCu1r|AA%vfMoR(rWAgIyg5bGPW)(zbK7Ra2!_Cc| z>Rnu1{Pf+s)wlvHZ){@*is&y9`2wU2eVC-&3r|vlv19c3X-y3s^eSsG>ZdoA{#f{l zZ{ppUm>9K+{eg@`(OAuuUaXfz?|en*o&&+jpL#BYRYQO=c59)@ppW8k1jW;lqr`Lc?& zJMDiDj7B2VMa}c42^AC+1d@;LIrLB3x0U5z&$;Z^^?G}xC{i-{{OH%OcuZ4A8mc~P z{Ky=g)^HN;thn*ZZFBz-AGc06y-k~G*NPnb=Z~iE1jCRa&w3+zx~F%YHdotqWR7{B z@O_)bsjGWaTAFWutnW2_i(wj%SH%a(2OGGG`^R;A1kKEou zl*_uMZXq4Hn^o+eyq&TW zhqqRD{j=#`iWwU}_LZMPuj<7i4Q;#DL)|uAtE{!PRg1lihtKNBG5C1zS}(SirF6Wq zsQYZiP^w4BeG4PSzyRHJ$E*YD*6Z{G$(j5k>zzzjcG{bld?{PMm3v%Vy7JA-v#vU) zNA@&NzBM+OyvhTkp7PE=nub8 zIPj?AbE^-bcb2GX zENpB-vB;!kWHM8P3QukjFe$V%R6cm{6U4IwN2D^1eMLt~BE)T)57+(q@@_L3PIilK zn`?PM8#~YCT)ZBN3eDp?n9?qX0zSzW%&QH*I?Sn#T5eKf=}qVL61IW zcmw_&e#^WlegIjjK)Cd0|1Z=rj&5$Xc{_9)p!P*eB_$^liOJMCKok^jIeVPAgImY= zK;Zr7uVfw`&aKCMy$$Umicl3ib!Y3)o64(cY5wZiMwU{AXXu_w>K+@60=9H5zQ2Cz zt96OYF6Sxgsp_N0*0;S48Jo`zGsJWarKF%vU`IF3+wg`DQ96 zriVrIt?O-}&~8!DUk|qJbx*fA^TkL-t8Guj@XgYaB**D99&(ei1w)&}*)OGQDP{Lp z@JS}D@06iEEB{0C?5ut4p`fH=;m;CQZJ?n5VO8A#>SnEAVR5a)!v0LQPN8F>f=9W7 z;`N1j&VL5q^qp$DLc_|Mk^JQTY3(gvJY6*D?Kk*#F4L2nce*{5vDfv`&Wr~)c1Y?b zJpcAf$tB@V`TJqdJmXWAv}31ssumnNbk;53W|O? z?BBQA{9T>B$`$r}?>)R<{-}}Br)0^ICD8x=XRB%M1HI^)P6*FX6nt*e)j)EZ7VH<@ z;tQcrzL;0U6gxvsxdaoC*=FU8zP^9lJp-H$20eM`Y9(s--f&CmsfCNG|Gc_dc=80> zeLOu7Yd&BgPkae2Z7$p=iT;5`+@5`ojt^@!AEPC;35uTx?9aeWK(XwH?+Eq_=x4Hh zLnZmyuBltzdzKA~xQ(Tt@zUVzd6@5GiR?0E6ekoGX)q@8)aZ53K=O9Ug~sL##*QL^ zpP8k9%I5~2J~oKXNWt&r3s~%Zq1;Vf1@crj5${#i)NC!K6;D5Y{P>TFAzikeMtYrT zTrbkk9K1q&G&(xg^6>Cb7wV#fP&(Vbw#8P*S@qe_+qI=0koy3-4eX!Cd6T-;+bc|M zJzq!foN+cWWp;5*9jqOxTxgUswW}Rz3sH35(i`z0JX@gi=Xl)di9UCjbIztziq zk9i7Bqb76A@nTNcuz6d@Db<&nZCy^GH0(S=(aEn9y*}n1mtJp_f1x;HYO}A)v7AeV zZ$u^AWYS022U6~nI%ni$)M()C_!lKT4pX+HZYl*0%w=byF0yW; z-X+bG(uOlfT8%`Mx^63Ma4D(FRCkAH2AHPtu{qM*};XDg9Gu-UJVn8 z%u}I6n3H1R;7FTr9KxeMT4=ZS>({Sj^2^!_FC9m75E`2N8jxQ6YDla_0nQp2ea}Ti zgk|S5gg+;VSErHla1r1JAr+ZROkr17S2A8A;kIz>ILd3!x!)WEoYVz0HU{3)?easg zHUC{qFRh76Ws}Y>adCamu+=Z(mmd!c3NAu3#Xm5Rj)-g$)jELC21C5IJ!R5BOd`KM zrOX@FPi~$QUk*zrAX+|iV=4rlS;x=8#ufnhgy1WSGDDj7Sx0$30L?KNKKnk`GFUCx z5_-Ca?;jN<64xHFCr$6+vWX=b;_?B=o~De#6n-@f-61R*gt}>{qy;A0$l*x#c0w$+ zkyl2k?Cp;!;(1L&)jwL=31JK+g#BE&jih@0L;}VKv2~2%#I!!jFDN)LK2V>F;l{+)5w2WPx}X$TANar(g~Lh4#R_KW+gl&ww7th_tRY0kKPjZfA!F! zbK11}fNja$$p-sdZLJGmA#ycUukg;O;cf8GI2rqiq5GHL7qeP3Z_Lzq6@N@b2e$n$SLFVj**u1G!z5IFm*G3}aHZA5x_iyTyukW(1O<8;Zms*zDwSkx5Xz527>z z=?i|kZ{(RDJ<3Iz+~h|<`0$vY6>rV1H(BZm*{VhvfxQDp+wmI+R0s(ZTH-9==H}K8 ze_ayNphP85zHd1Nsjn(?MY^@E#z186fvT(DLS&Px0V~ygHh3EF|+lf~H(Y#YDiM%`-e5qa%)R(4o zW~gvnfp8?Tx+U5u5YfQ61C5oHRYso*=YimEY^$hc(NO>Xx7m-(*?S+i=N!Z*6ZP{x z1pNbfr|L=n5sjCP@Qe_`-3Np@v}V(Yeb@5+^^u++iA2;0Ndjy#9-er)hWsReh>LQs z;uKn2TM6qD=p-Xrs@4eyWg#$*lzzKr{9ed`ef*R(e=$~=xH#Yv-(mkp3hFarlpdNe ziYz=lgt?Ykc5N1urJJRsxPVWppe?kTlQS5iym!F!)Y+c;48|*D0CnlorRES$NAKk+ zkyYkfnx>VQ9qHcj=;Rm3iABRf1-PtQ7uui#t`jxSpZB3MA-)ha6gT}0B_7Au;itoIu-;!z5 z=9~{0P-5k}(+M`YFTUjP?iPI~RghRHUU_!3>urNyci}lxd)60;0!-}BB~GaA_g*H) zB~zbX9K(6PE=KZnc}e#Qo)`7&yk^gH1euL=A9|jtWoSRy=&I43rd2iazO;9E^6l2q zG1dn|GHv7Dx;gohyTwnLx|_C5?rMF)ZhE$W*4`G!>&Nn8M>*%u3~*G=&L4#bgOumQ zy$J^tYS3*%V)=tubBukeffG8-jvYJ5zkVFasXQX9;|2R@ew^4u@CdD_e?iS8YGoSa zWM4lf8YjPJ4;#d=l)x_`V%|^?O9osmIP<={+YwDF!Ym}cF|7U%A3t6MPpU5~r=sz< zg~7hbT?c7P8h6Kh1GAsBMU#-a>AUSR3^E{N72-S*Uwvso+G@H%ziFecig_N9gq8>!_p?v;#!{Xbei4Iz_!)qP}>xXw4!2q_gNjPQ+6f0 z!#C^M6|Sz`*0BFhpgZTZ)3u=9ul3BhOPp3p>QK4a`=%D#t7D&+?+9go<+|N&pXqL4 z$pn+)QQoG(IJW%4*hF7$reUnXo|Xum#>7LXjeZ36$U@edYR+N5vGKN|@2U@#?>IO3 zhE?Yp$8mONkLWYl{+{XgDo57F!pchP<@Ni#;t@>@JU>G#W!0)xwb^yDd|tmd)d|@d zM<)Kv^A<>ZGch)%{m(0zla~9&gy+h=#y-(8atdEYGbH0skr9hP+SrZ|Tc6%y7kU_y zaEaTn5nkU>W8ygUieU1Nl(t=rw-5CH|_&`YFD}b9q*+PpND!K4J%udpWMj zsM@o@3MV`n;@d`F-E8}8nlkFy;r(;@2{B%`LE)%92cEyLdS|(!q;+Oq`qJeUK+mlY z2VYIH^O1Y|=8Cs(i2cHuX&xh4!w0r)LC-_`ZJL~QB})Bfr$ux6)7}pZe%IZncleLQ zx~|~XRg9T)p?r)zO$?q6uJBX04Yqu4w6*uRJ4ZQs6sNd(mRQ zteraDh>$0WZ;;MAg1?xX>lXML5oukZMpj=lKN}%Mdx-wDrGft!|83}|SRUoNuJBXk zYjpipd0!eP_v!P#H&4AT+Q2x>&Gdbko459Q*U+Ao<>ZE}Z$%&OeKa%Hs^|0dx02oX zjC}oefma%(J3}1!2Gadzi{7;>B}`0A2>kThwrwL<1{zEd0(KIp(zV$=vTs^i*nQ_Gdv7Tqn?HCzToAoDRL70& zEiYIx3-Ak!&QxD3)0i6}Jv_0Omfw$)*f*@=EhuQy(3;5xr{%CP%P51g8?{bg)lYjryD%KnB zUg5-x-P2thEqo@99fU1TEcRf-Xz}aEM+PP)p{u%)6}DJN7;muqYIB~21W2U9+@if0X%N7S2 z`&oLkw{8;Tk~&1}*8bybMeVA;y0y5b#V;`~Swe>(Uyood!ah6w{R(0|XuC_qVppI{ zdex!;^W4J)Hu}IGVIcWgtRBs@wm8n(bcSH5F}l|b>VXNHY{GO_S0vQix<&D%fZ1xUBFq3ejsP{4NGx9!H-QniWb<$TnIMERkV(?SCS1Nf2M&i0JiXz9vz8Gt!KloTMq zXMs}!ZYhUvkKS7R)O)<+pG&iiKHQA`C+sH}rJGVhEpdNozCt+#Xg`ELfO$ZEIhGZ$U$$ zIWFO>5yGE^;}Me7K2OkHK3sCv+JCxla4_BNH>#13eSMp~=|2~t=qYlJj}L`e<@xhv z6U!X--My!4WE7X1^;47o-e_14v)!@C#P$mQ%7^yaKC0^KuVS9onfyztJcNH^_4DV+ z(7XFFK3)YFDBN_z;(|yzU-kE8;PQTwA^tIRYsp!+ZCf$-ckZ=oGu%o))juzqcXQe2 zb*pm`nqC1SOKo=On%G0j=;`SJP;-?EevkBPt@twkZ0f+wrO~jTOaQ7==Zr!FH~jUK zRaQQ?1GNgtBQc0`$+^F7ZK(`d$;!$zp}LuW-DFqy=R(tR2?YZzy4XW$6SNerH}KFl z|K#N4F~Pr1^N}jQ{u*$@x``(X9u-8fr@$IXa*)42E&5GNwzeZhCcFRMT-oz~um2GAVgSU1Po3;7j48^mvC-5(;?E3}$7 z{9*Xz%X^$PzIa#Tl?MLD$8YlF{$;AP_>%wq3q|aVCxk^s-RKtg&cFZm<$cY+AO7Fi zv*|wtpT9r!!u%`ICHzlc{<;5Nu>bm)728<%Ggun<{rgiF|JkE){(t|m|3jer@Bgm< z)`z6XxdT`tb(iDWvjO4ZD~DSjgrPv;E;Tx_6p$#a%Y=#cD9l{<3&&L9K!R%uX_&S| zUkVSeLH)KIZbV0q9=)TNp+oWykhGcy-n@A61TL3{i=9s%H8doR6(n-cz{T9tY7349 zh72wwBy2#pWHn$1WVwScrEiRBK#w0`8~;O>ME4&n?&SE+M<~q=78g|Q2}9LVZ_#o{N?E6M7mUV;MnwL)0J<4ia=)YRZI4**qHhS+uulI+wn>{?y4DpDz4D8N>OfhSDT z!FNF2`EI-nNJvP4GofE(%xGB1Dh*-6oV(^b39%Jdh1d;xBr0!92}VN z%Q-Mwu)A>zaf`&b<;5n#81k?zV{9}oLf*8RhbQ~dJk+P2b$6RcNlRZM41nJO%UPFF zB;8M)3pIBeZq88aVqP!BONA_3xekc<3G3g$NK+JB4a+^o_M`W}hzJIq@vZh(LPOUy zG&DTx=rBY^JTM&p(gS8J%0P$+yhKcFT9ze|X8@##ovOzDAHaE`=4oS-Iq?TK&I+Mu zHsn#ys;ZRAJSWuz?7`q#asgFyOnx}I<4aCF8CAGr4%~4YxnqlXm1<}qD}iFIKl4cs z@Lep$yG=9xbU>W;z7(!p7wil^kaq4uzN9y`_tH0)*h% z%a{LnoI58#7e9=HPV>OQgNzi>K^C;e4YN|pK*0i8?2aj)IDdJZKb>nt+YImN9B!q5 zd+j;p{f;YD>UJ+iVho#RrHETME~Rivk4a=H)7>mCmZnG_l5!+AkPszZSXg+No2y-C zYiUUfgIIuO>ej;Zn3oXPDeGzy*$aEv#TL_?upcu zpR9dyo1VOGXTLM80uYsY^7QFG`;zkV^2lHX6P%)7U~Gf*O4NJy?9i1Ih0*P{74ZXY z`Ug#YJSAbZN7I&bjCfNePo=3}w7=~935Wr8fgnLRO`)~Cv@S;E3K7J@M!zZVC7~80 z9F7sOhxXU@q2i7}@=cE1Ur^ak9mVcY!sKLX(f;;|kQeP-gV!qlwMIy^>zAxOH*vR* zJ3@J)73NXILXO}8vs0sC$OcG484-5*ay5|YT|RRj@UME=-g0+Z6;@XiGDS;Z9t8EW zv$rSg7(^x-Y%g3<;WgWTzCQxINDRz=f%cF68;G0g11(2YO^sj(g$})Ez6v7!7z)qj z>EDV6c#P~gMB%?NM*DK~+dX1rsr2!$@ceTU7l+pSQ8ZdWUC!K7NB=fw{zQFD(_Bd% z-ZGsrxcx|OY2eXGMtiJ$3PtOd*9_``N&UFAR45hD*K)Yt66Zy;a?hH=1Jbwv1#G(G zz%aRXpIQ+L{snd|fKq0RqFunwAO>R>dQBHMmMUkh=l=|XGuT|Ku{4>jZ~dOZQh=X& zlaz>fd9Rq5*jq>=?l@Mlo}6Y7L^nU4G&h-S`|Zcq@rW)^m2nC|LGhrcXGz0cNyJ^- zZ)XKP5H!g4iXT~>!gCeJb4~nL>|Tsy=rK_ju(Wiy+T)m2 zJNKpPF2h~}xO##AZ4T^r9jBz5yZdw86};JUA~{f8V8=d#2J%+5$g73)OP8uzcL;Qv z0Md#4No$w<9WjgPJh+%PzHE|*h-;fwqIVFcKyt&iBlwpDU_dJ?B{eOL!~sN>i(qgf zRXz$TI`si%jZG8aR7Mo7H!+;(z53d_FOrP2JCQFSdXPI=D>7D^cdRmunft#h#5WudBZ3Pu%P;gL6@ zSP+l(JAGh*Nwas)>9>2gY~QZl=ji=5)bScoXUXINpDX_ z4NB1!6z|;pe5ySO)yPRzH`IEbSb>l&LSn`3@2%YF0a)jo%7X?r6{^ud!!QoXim$%WIc_hY$4j|cqf z{j;#0LUWRN3`VB{B6y?9IB8G6y>pS`4P_J9L{b}A0tgg{2nh)}0Ra_IB&(9UPp&Q^ zmuD?4{>g+}uiVwJ`FQZ=et!_|!qA0a1VP8*+10_Y2l z54Q-3+O_Y}PEuZOapsJ$TJ#P9?9sbU)f(7O#AjIq_KkJ z^x6&|gK9Dig=vXx_a8g}-zF9;=9^zS3$h-c9lEaO^Plobv1M4e2LH+*Hp;}Z?7x?2 z7xlkn^s2I~mIfp*_>Zq1)Bn0S_-`fb|5b(Y|L#MQigGC0H;#D?#DUvy6pvdMYPtu`oM-f20 zhmBr=mbns*oFFhuvEbj9x;tRtIM$-R+SAy30^0APVe-w0 z(Qy|r7V#*rbMl({x*QY>GBFNLE5C8!4j^P7;Z}|_lTTb+ylwE{;&a)^Q+w|6I4ngS zPt|5qfL2QHky8++q43K1QNwwRu-U;wl}E>GvL~2j32v?i7~$b;6KS%5VbJ=VNR2)^ zsZRPkptKnXkj(t+`%+H)=mBusNJ|dMJYmd|)FM7UKBWwO!ax&Yh1(a)=N7tB^Hx7;K569!ebzZAr2kb-9*;O z%E&BPoH0WG3(qZ?*rR{Vdvi4|E)>0}O;c@Ph|C$vi!@ zLgN*EdAKFd?D!#n27c`>v9Yqd_6fgYg1j)q2C|pokfO4=?%831u zb!d)+eD!sM4X7{`0?)Q<@$ozCE4`+(!S>-m8~PK*2pXqwKaA9p3+;JUR~jquZcw7Z za6!QSjgi?$RJ}1Sz9pS+(;`ScK?!AUZ@<%?_zs}Fn(TRlSq>U%_o>$xOoU?!V|Ey> z!K+rm9wA+!C}&JApKeUK)K~~yl8_zIhunMcU{L=1SFhAl%kX?bKe&W$eZ1T`o$s@m zn`2&)4Ej+ymbzyYDb?p@#!0bjEyzQM#wJNovAEGn&Qt5Ss}!6>hQeGya%QB=(T0?k zV_n?7vsXnMelsHz-Xyj(U2iMQIPYR(&`@ACq+cKS{yk^f2L%|dt*vcywyE(}ja(ig zhM?aFYwVmy^{uTONZH9?D8Z#LdPY&imDIW^%>vQ-G59%aSXra$U_ih3`|PQ zKVo}GOa1G*--S+NM|B&`n(4-dR_X*n_h{l-uZxZ}=~6T&_Y6CaT(y_Woi z&#S8`%B;-eBJ^uNqB;ThTm^e;{e}&MLkELAzK1Gb{ zq2H(01%!MI3uV1L3w9&Ej>pH4OnLR%P+JoB`wt@+QxyE9_5b+|-pFK+7Sfcf=oSDL zyq6M40x9-x=5;sNutwPJ%Q?{JKkkvcejX>GIMG#G9%=VG6KQ-F5 zU1KEq=RkOyaqb80x86#TdD*R7>Ci}kd(GF5raR+{m!8Co+7gAma-?j$ z4w4G~U;w*hT!IX3VoejNsjKnv#D+K-PYKr^%obqX1O*?`OU4`aN5vZ_=@9Yjmj@X$ zhsV_uID)5Y-!~M8Kk!bY{Jr=g?Myv^^{{ZJXXa4`5mXdH*s1fVDK}t2V3?R5fWk%W zWeZ?w#K1ilAe(>AK0secakX*wKrfpX>`!G(R4x4KV^0q{-E}JKwKdgL(J0r--Iuo5DIda$5BD7))rolr&pH?+L0*Sz;M!%Fbs;- z8Y~ocGOkWlBalY@M1jNL>gxs|uLdc+GAbgHP(IAdW5ppxHnHG2kW&Py5NwH1z?Grc z@XL|pYCm?gc$a6TMH1pDG3-EgQX61ErmKbO>a+12v~G zE`uM#Wh8e-aa-`>#f!srzF*r#Ntz0(*iwpSFZ|Dq)wI;rSdS*smf_3>CVv}x+Srayj`SMj6Vb53y*Qk81rJT+3xG` z;?=FzUM1RqpeVM!?SBW`|Nj2PZM68e|0{t0|6Th09~bg}*DCn`-IXMr&$H_g!**Y~ zY}vX}LAqzbKJu%0f)RZDS-|q*>^9hAt>)ke!gBoVd5q2wCCNlW`j8beTchbQ98$?h{v4g1{m}Oy zjcWvakj4M8Ee4WJP*4QiAp8-MWjYU&U;%+wt}tUz1y@HnpawJD z%|ht*LS%1mhEA%vwe=-zx+GBSq%bVp@CfJBX`HG^Fc{U<)j{|@4vZ10;*ndG7*8XZ zAr9X#eC&|x8RUW~8~UEQw;^%HX9r;z4o_72=iS=!%e^bAe(iJ^sreQDWy2P9VOPUc zoM7G*kP!rPAjmu$dwtW zbd;#aj}LMgF`r5ar~cfdNfC`)a4{3p^Tl3qIHCuIbH;+Jf;Mu!( z+7t`wup~g9&)&Z`rC8AaJB~iax>ukz9!7VVfg+s%oHR=9U0qBR!O4N49kDM)xMBjN zqz%f`wcEF|$LDYqsP3-y|6aJ5oa0ecwd=oJ?ft0-xC0aDdSjK4I2hI6ktpi3dqrPS zSr%uXWRg2;uj7^^_gbt>r|EB|V9bE{)rT4ZpSH`bAVig!)Lqma^oI^+zbqXhMjO{Q zt28g1Uzi!Rc}2udh&dwou7qw$^d68t5d0Byjs(I2HF&Rm2R^+IYmndY(H3G4fhKt$ zU>W?13ux~~?=bWO2)*vCFrZ5S8nmZg--M#jD#$beo&fwOOYCbS4O9-Rzsl}s#EbTyifzy#^Njj!zzmZ9LGIWY+&-TGMe$`v#vDT|R zv8aJ9Tx1uPNbWR?Wwe5heg>QI;kZ=hxy{A0)xd!h#0$T*6HSDq`aR zJ+R{ByllfOISoLY{b9<%CTzk*IY|7AXR7-S1Kdsf})i6Gco(c#EIQ`nC zuXg3cNE_iAUqnF#m`?lAiGK+0Qx_=!B!vFzxq4t!gs(#}Ff!_7#+1AnK{z|zmc0RO zSTI^55bKfw!L%3J-|8&kk5~;7G6tYX-j75B`1QWK`vczs^});8-&btcnZF#-`@pqL z2Oy!c3ilr=WUz`r|G<(a%Iy%e-|`VP6RJ7l&Lji|_U@$uSaewK-TCE?*!m74W>Js8 z=nn=I#Utl&XIxfD1WP#d5l{~mkgfIW1hd}(pFxepYkvh$4njoK&l4vtEcgnFvAVH7 zL19A2*$;=pM{s9KG;d(DHk-@c%(zqS&p6vNy$qCRX!22x2w_pbMCIg6tSc{FxXG2%k#uP7HE78V}FT6l+~l0YMt8=9J$2$i1%7vSqx zU-gsKOTrH-iT*F=*|{o3eO@tL?hR2md@$i{6hv$k=O=CeQAV1KgZls=39O0`iWfr< zcOs`z_+j}hN2Fi9VZ#GAUr|xfXQ=ln7NqkAY6!TA1j(mJhX4o7v#u^<42CmO4{>)pfV=?Y#6I(A6`&LI`0--~bd-plBOcWj`?fuaf6yr; zeR91VlpE?Uy}xH>E@8)5$E*B78I=SdNt|!Hce5gVaG>5@1H_$T!6fG=EUnq|Z2KAb z@Z1kvTC;Uew#J>SuJz9lW0|NVcM*Jk_`Wwv5_;{TGp5!F(O|6+ zcwC?ta9+wdmf~v6-LW7NHV%??1k&A|{O`DF^`W@a##y8@*C@k?(2{hjKCV#|g-e#x z`>L|XJXoD;z(8Gc0k!jG@On!gK74p}m@L$&MOh3hJ`Y_lfTGOydC z+E?%@8w`D9gL~&BPTGZ#cCtMXw{lB@KvLIA5d9YdA5O2cr)7o zsHV)5Lte+MFzG9rF8w3SIqH#xWa^UE7(Y77l!h-#5!PWHM! ztcK~Rcj(K7pQNRwrK3_u?C*iq{H0)7$lXabb@dyauYjJlM5}9RMpvj1-jKZ_#;_*N z;msJ>cOae27_~7POvdSj`+syYtks27CzD*v9QQpAhs3|1UWF%w-YEJ%2+4wDeb@l0 zWx#3K!Kc4^JekyKU|2H%=%U#R7&ya}{_yH?fN1+b%wj^RN!FJ(EY^|8E~QBMOsdH8 zH`=|AA$b!S>WK<;ImAZb(8UKZS(xkx(@H$eFl3>`2n_$7GH?%Mn^)MF^uaknTPy1; zA=4!Rac=a1Bk2L4A1Q9f(KczlZ-ARr(Unq(^W@Sey7W5g_x^j+H3cr`S!Y!#uwu+k9j}I4rVUmH4 zDhPWjbjAkfRZYn1&4lI>nfB#}nx-ZPO4q~74?};c8tWIL4h4i>0~}9RO-ohvDu*HS zs#P3Xx8<%iYQ{_PVRk2V)x)on-f+i_@r%V+lVFthIqJWQ~3{!v-U4G}<%uTFL07)_OUF7mJI>q#e&=$%GR z&CR&H!??(h#Zw$nk#*}DKi+dNG{vxJs<$G-o&9Z`c{FK2~|QlvXndhjkf3pW1m?jOMr2wk6{ z08i#=ri1KlcL4cG%WQO=Sb22igEokNr;bF4bJc5(nV&@B)+kZa%UEoZ{qvll@__^Y z{H}BziSCpm4{8w{DY#82q8ZOU9-7zq2X#-|j6~`pBj5lEvzk&?RSf}oH6k)PGI9$t z%$*biXgg8u`&SnM@(M#$CuCZM>`HCg6VhIn%m(k>(dy393QorOJ>+_#bjsOB-&R*E zllgoPFRw6k{Qen&%2MRe83AnEY!MGSV#S{YFT%4z&!h(M+oZy=T8BRt zb7?&PUe0dcN@^O8q`1soSh#%K0osa^p^ZRbrpE_g8gh8-a-QI?c?uw zl8n4Akse1vi|ob2vpM&NXeS?BPN8z;2ukeco?Ui7sC5KUw{h?s_Ts^@scAT=TSZ`!)` zfcK`8xatY=e_CTdZ8!W`KiCo(MU>1#cDCs|IGrW zsV|f|3Mza|^Gr&%1_NJ85gq*st~--@H%wsj$UPshQ0nUONu@qSbe9<&IZ^y+6*1u6Se?k?w=H}(y z7`WYA>08+;2mROB1{WDL&=5ky2oDOKQNq1J4=v@~@SAw)&JbRB;(P22L{f2B28 zaiHzy**EF1AM-|8iCc5?x@V-5-TL?c8^dZn+?=!CSTSgT?Yfp1Jkp_Zde3|`V+%1l zxU+$E5g%UXuL-k`-XtYmL)f@9yfzINh1k4v{m}!|#b^|d9l`ww+;tCJNHess2sXaJ zc~?MK<(ywl^`@9^`mOu9rQZ%6FE{qOJ7uopouvc(y_e*-9G4@gA!JhLzMH-*+WM@b zZcSEs+(TW8whMV9sUPbcyAaK-@AK}kcS_1=&S8Tq!@n5gow^8z#Mory)e2|nWyIS{ zy~K4+uPqLz&d(wbGJe`ik)LCzlD7K7^~%!Gl6X#zm}cm*SBZAW)>c=)iViQh;zMpL zQi$su=P!8p7^sXv4h|J@l{-xuz9lzNkAs$L-0d?sLUW?-;Y7}@SyJwF+h?av&iODB zCCN&A_>&<%Q{W-c-NdXbGIZapIMr9lsr=nc#eurkb}%2Vyh*i{`}XfYT=&rG%Rd^@ zeswx8);0C^kboJhy!*|YwQ^5a1xKniZ~r;VTr!(A-#m`+;SX1l z%89OyJ8!oPLZ=c!K0l6J3{k%MqqPqMMl@1TB@7h?Pcx`s{a#ns?F3+BQ}S7Q42~rp zEiffPZ2~z+dQbPkIK-!C`LGat9oZeym@77~C?0VL@>vFVEEKA|HWR*i79&sOjm$1*Qt^l+M~0JaX?^hAgj{yJH#qBg6g6 zqR{Pao4?6%JX6s5H?re{I?H0WkTql{S-UpdcJJXEE~yDFhp7cmZrDP}R%W<=-@X}n z(dTE0UQwVMpPA3X-l#rU_A9Nj-~r1RyU)DS;!{K^0nM|#>T{?&#@%^68fId zCA;Z3ODxzAwJJ%RG?j{$ic!2X)x#S12mjalIM2+ld;5CFa`qA3alYnKHl|H(>lbc`NyU~_Vt9$hg2lx7`s ziczup@#r%8@2>0PS5kFqVtY>yt9tLS6LRY0{jnuAR?e<}Gq=mEr6HA*V;g0_e^7Tn zT$@Y4HC(N-EQ_kL>|_4Co4#)?mnX;mPC34G;bn)vy6w!{R(I~x#LzwexQ=vv(MVor z^TMl>8)$aEx4&vZ>asPAYv1wSbb4XWDWzlAI)!VDsy)Lf3TL}Ch`nnN_E+JswAHnc zOi@Fq<`Z?NXgW?sWTwQomqau4m^PQSSUdn_IQ%RLGyKse z(;sXa8^wPuqSr@$SteArZkO@C)QO9qoLwyEX2psX!Tw#R44dJaWYxs&amdnu*dm(6 zbRMdW969o4u|?mx^{i9K;RHNEXO`u&u!P>*E7U&7O7)qU|PDT_&r|IoanR)cUQ zzXwAG^r?Fh(p}5}@3m{uLS%WZP>Wpy6O`E&m>(Cw5sU0X=ANG{RM!ooz3Tmnim{D> zP9i>x3$kg}XW{Dg>({#}4^iS$7{|B(@i*rOT&EF|Z9M@SiaT##Kbt+1-cxv4c_!NJQ% zHc!>;M=Rje7Y%;3J8QkAv50me+Ba9vIo&mN$B(sQ{|+u`n_Uh-VO_C!;-#|w#`7Z~ z0N=Lx^mgBuj(^_H-gP@DHB1(bht8wHsr~`&h*PvSqa5tv!QC8U+sX=tn;%h6K^k0# z*2pMsccTvf&Y>iN@EX&28yV!AyW07e9Xl%I5=khLMu3ZEYcqq7Sn8?RS%Gc!F*;FS zu{@w4vxQ%1eaCqO(%@kohd3^jBm|ZcJRoNQ!y(RNUoT6 zQnta3=z(agK6@$WN-Z4mq z9BA5gW9e^}9--h&|Ip#>1`HVB{(c-FN+;s(6$Y`m0>X+}EyGQ2P+K;Op0U9-{e|m` z^$k+XMsEE0EPePcV?Di{#?L=21z(9uoX`EoxjJsVft_gGr6H_F(dVeggX0%So?t%; z>eG1McNkvgk#6($_RcU_ik;`7(9Cp`S!-3mM*AfnKN?W5J|n(!Gs)bSjI0dW-oCkI z$V107x@*TTcy`@ppSD)&q`dYcK4z&7i1yZz6QoIDIKitUVzd)=ZAc?0PMfwTE^aXR zl(2l^)tkYLiRs(12vs{^BGk8H5pQAB&oVXXdx#%J`lO)+m8=+u0LwwUcj~Xt@7fMe zSy!ZlisyLZeDKyf3RFDUOSl>S408IdcDkph8LFW=3cev-XRKye>*B9N+o?Q^=Hw}y zp*A8^63}`Y%JdpI1(`!1pAYu$r_-ZqN$Hq3Zx>GacUM82!0mW!KJGmDy`N%pDg6$d z+8uYk!;Rwa-~NhI&iLaia$^;bgESQQQY9ZOe120q{5!$} zhFWFlsr3V$H~mxua5__M3*Aecl`*{z8N6tzL}zln1KM9_*Di^uQyCs{db>eBe+xV>@Q$ES9=FJ)K$=tzRmCgGWL=wyU0aPoj%uMtWxlf z?c3vPE3)c6UtT&XEg5;rCA4aRV}a6mGvxTob`vCEfvA2kVeWB%%XqHYAsQ`eah(v*n^D|NyUu)Z~yeoKfiX|g7u$T3o{#D(gw)u041#Ro%b(=+^>!DRc3h+M81Id!8X(*t?eJ#?od3mXQj& z?tQ28@HJyH1e6=WX9e#YJXC>!Mkg;EZ1I)xE_nQHOmbxmsa?{-a-Vm77_$NeCD*uv7apaw- zq1&4vu^7b|#Ye9YJR$b+&VWM6t)#`nu!>F%(LbaLhE}F>h#H#`rYBh`$rF$(ai^e8u@7Yzk_b~Zav;gN^@$5Wq;B-qs5NIg&%2pk+jzLIT+0fJwwbei?W_Z zjh6MkQ=2lEEx^Mksa>A&s7;0S(~SsoLN+*wRH=+W>k8(#k22O;XHTRva|*Y0W?U{$5RZB z|4+x9HpVze@eN7UpB^v3MW!Ll1|InX7%!pLQXUaZ5g+n0=i0TgQi;&0!a+kZcLKHx z*4@QBKggTL6$b2U^^bmBk9qLcf|GCe&IQqawZm` z`V+WSkL%w1d-41Ck!1sWM#6Wg9jn`w<|O&PkUd5aO6mcbN+jnkz{x!?H-cI1>o~jQ z+m`VA5a8oUxkI{)cf`lTD{)zAr-%FO15jsL+x*9gXIvjf?H#i`y}SETV2{})6<7Wj z8$Kz>J!y}U(n7~_@7jK~BPZVzh03~_k8FR;IFX!f_wx!SVa$IwJsftv_2@wD^%L%J zpN$fuJFC&OdbE(OA{;*3BIEmqFP_~!{J^OGlggf;$UXtZli}!u>Qm3wHvSKwXA|}J zO}J2u_ZF4-wN^9MdmJ%yk}PaV+@mn_J{ppv$1eFbK~MW??a4rPs^)y}V9?J9vx|v} zuCxuDjIZXP8yB}wtR=n#RPhzQvWZX42=fp+N5D_e6X6_4dJT~Xn=Cbn(ybHP!uRR| zusC{DTgy3Nr-Do9jok8m^0GaT$#Z3V9RA_^pgUSX5Vj<$Crku-UKPdvh)i5DfBw-* zAh+>I&h#Mpc7mqK_Tw!TU;^4gkfl3~$C}BNLBCiJu66I|m}^hpu4z4XuTxF+Yyn=q z`pz|`hf`t%8rA`A`_Jz~KBDDAJBpuFjxw4>%ym%2~$pMnRjjZ4X zrJ^3dXr5#Lxv_fizWQ(%23Nx%O$oIpAYo5H2??4`=xag;q}Up1u6{n!+j+qIVlN+` zUJO_2C)X`|cqcvW(mt}+&CNvn0rCF^6+mCz>arO|cuJz!9B}W?h%yR-U3dDdFe5yx z$v~y_OPkk=^tkp*z~6a=jo}L56CMp$9~=oxu@{BNUkLoAb3}bsR3BLlq`?V3l2c$( z^^bJnz$lAsTYRvh=3f}3Uwkm&M8Eex)61s!8djdP@l)X27bz)aX3`CVqAw)pR1s}Y zt^~oq+O=zkj82fcRN7^uziivq!uRCKeQh*)@x<$Xt-PiD>&t}RqbJBfzMd)q&isVmPY|woQe+QFt0+Og!R9F!^x9}7>hYQVVf3cMR z3HRE=#GtL-wV?QLK=k>XIVv7wEuODl_2lfo5Z;Z44I3tXFsyE^tcOeW1}q_I_2j=F z9--(uR|AG|;^=|Dj$5W{VA4@+5;dnK>cwlkbm<5%pIy9z?tTy4YjNXl%Q-+Bx;pT~ zFyvCwsn4|;JfzX{p42OPyOi2YM^A%YLqn!vJ(-z7c2ny&g=mPx4J(L8BR)XI>r+*s z6mp~}g8WIueXMG5 zgXg!0GyP*fLxV&%0F5?qgPoEa8{V8uPjsOC-_=-g6kVhYp)R_5`}SYZWB>CT+c!5} zKo!$P8D+}m>-1rE?cTi_ar-}{65}aWByU}b`gaUZTStL*RPOMYh{~(T{+r{y^9VGV zf>Qxo|0KyygE2vk6^J4g_wH3L8yX3Nb3HAsAs`n5Go#D?2!}uGUnVlecHcbBL6L$1%bcLqAv?v)FEw0 zS)%6Flg&XyPF%XGHq!P%5muFdH$EKmO%prO|#d z9qFiOZV5wFfBIaB1IgOx8JTVW?avtTKIhsumsXyCF;Mt_&91=G@`KW&r-h8fCB|^W z-+eNZrIy9MLbCCaNi~E}yg7pF)XGhf4!Znu{r2QlWgVD_o|Ic|X3ga5yW8sz+}fv7 zbiw>(H1+lDUswLV*?jMk8#chw2W|eeBz>w&Dri1?n{BNGmGG&3aN*j{|C^@te=~6S z|D7tcc1iyG+Ya9jau~a~LcN&)Ap5Ac2pg=U>z8le-tuh;fkm8FmWC8@ zvLTk;pgxjDCM3#VI;a?R#yzQ}kuC$f`chr|P1DRTw0Y7o9}`I?_TZFhf;vq2ZkdOP zN8lCcG&$Jjvk51Am7GDLREDj3Him2WGjmd|a$E4ZsL-QPomRd*x&IN0e|cP1o^QaT zobx-OY{AO|1G=AZz^YJJF$Q)$H+Q0*g@P(uI(8Im7}ON)B+gPT#!igM2B5hAU;?M) zAZ(a4zbk%k*`Tj?L3;2`IM`CzD8@j1xX*9J2|J7ojo`W$ymJ2eHCJmdx$ZsW>f;7O zU+}85bS3w3%k&X_xa#Qwha_*jvjrD&`j%J8C2QO~Jklt^($B9oCrd&V(+TKW)UH6v z(VnA;IM8{85DXh1cV0$e{l6K2*hJ-!#mP{(;n+h1%|JU}r8F9{)g?78O&mJtC=<48 zI{HPyL#gT;nzidbD4L7&Dt6K&wF5WIL=J0(m{Q@)e5{Tada)K3E1>1-DY!x@6g%L` zq*vB|9nF8_-=avP(lMXC=WX)m=TWbQlibLS$^M%_`6mNqz`R9kBT+)^@1*ASdJpX; zGO5^n;rX7Tc6mhm1zIA}&?%e|Iv$5Yos+V46w_ONW1a|xuuS7`e z_A*a{nL!c>z5UZ+pt}zoIKB8caJ`GfCj5ipM<>t(Sw}&pG=X|&Af%`k_7a$;82e&J z(3RmwWWLvKc5!N|5Vu5bj4y0vX##^*I`s7+?Srg{n3PDvf>jkNo`F~*IMh^Ev>&ly zg~u6zEsFxk!d}5RkqLDNk#IK)*q8?(0rKVoXk$)~p6R!BFKdm#(yQEKvY&qA)Gh%I zl5B!}Y8OI`75Pb7PqOZ`avSX8R-Ei>T+VaMELlvEEro9fmLRe<}{K}OvAzGAt5NFayDqbXaw9>~mmbRXM|NR%TZ3i+!ir@5g_4^Wa`BpeE2lsynooB1l1E@ktjz95A%#v@uY+v04U+@K@tN)o!>FxpRzlxy%Df#EmpT|s`*qBq}wJ*-NEuCf{>fN-ow5HX2$pI~- z@E=FHO(xv{nIZ?jP*m1n=UF~jXnQ?)VQyDy<}ZT(ojZ@wqE2-*1}gGVlAX;_1R!cJ za1Ki6U&zZE^B7kUB|+^5$Byn+pAEG~FsM-rsP+b0+-e~iE3H_oJl;vL37(QYN8|l| zwEf7ng^fwLn`5!xF}4O}|Li+euWz?=u)Ok}+6wAqy7lz~?!Xi0gqeNVeD!QQYOl%D6kEsRFvK%uI+ff+i2X z^lr&uu5Mfpj-xTRK0L<$^UFFii2|hM3i&5HYdn>l(5pfrCHMNIlxDhnuxfS^B6Ea{ zlHxOmYI14Zk)!rcV|<1}G5pJFOESr^V3}K!`D#@@E2G;heE!05j+zft2L)EVX{Nst z(6x@j(lR1!3&KE|?Zb^OUCJC`89V4X0CO|NBQA6AcovZ{Sxq^qBF}p83(><4mW^vB zcqIhgO}3So`g`Y`Te)JzOD1}u7^dUZ)#hr;zdGh1^ zEq`EMvX&6kw2N*hx>w;ygS~uc6@-$%f=ZkQ*6Ew?USze!87Z$iDJdypMMupg^%24V zqIyVUH7L3SA=LO?Laa_OoX7;F8w%v}pS@ib6S4&VhRMP1o$xL#z`!%j$0HR|X z?%3Fg*P4~{L@*~#tD9(SF|E0Rw)gRZVoaxwF(^#^Y237BNVit7XN23KvrSwF9{lVR zdY5r$Jt@d`Q*Mdvun6iA^U5`u{2;RHyjY9BXs-I3+gc)GOiWC}(qfr4L|PyNLDxbo+WAn}~rs<@;g340Vw;&DIIT?a0%g>~3_zLBL*PW&vhfXdX7s8$ViYs*Yp+l{} zc|ta-h`Fj1J)CD*J7I%HZ*}9^7g4MDEi1Z?8l?iMcvBSUp#M^3{5M5hIuSw~A7BH% z?7`8i0H8xUZ-kaXQm7;4eaO{0teoEF<3I_(x*E>+8kHe(OC=+dz8qHeC0C5~h+(Z` zrfnc{djGBsmo!(vmrn;efRcMZztQBz+sMH>N%AIUI4rXyoM5&xl?lPOe_^q%w!1gm zO*r<1n3n{Qvn!ttrSmtu>E6@y?EtOsacR|i2}bIM0cR94qDuOVS>FWP`*GG_n0dn= z-Mp#4bYAw84_v6YU)&^VJIfl%`2c2Ypb+-~aI>p`@8YN>6W)2Pk{f_9c+j{ZqXO;C z4sUmh@p;wMCE@Qbx^Dc=Yh)+zk~dkYv)#M=jz>(R*d8mHPX`=`Ur zEiYbe0;QAxjii{GG1kzY8ZaY8MMdSy_2#1|s_aqX%r#P!?CzzJI&^$zvma9!qerAf z(b8svwBayXgAgh(R76-bDB1QfLz5=PNkAgR#M zU%gW9Od}(sN*WZ0ESR(lKvm&;a$}rsQ-w3zPVnv{+Ur9z>}I*y^Fou--XAPK4<sBFG z8<62_BQDDZB~TJn%?vZO=7{rtBZck9XLCz;bho3%@Zf~@z(RoSXS zp&f+IE=q2iJ}-FhY?9}hV5dgD?y9JFP?6KO)}D)AbWhVlo^epK{=RlE0)|Rz5>|yS zoJKNS4Z9f|Uuo)<11dw8`Lj_-+~BZ^2d;C}w(M~kD%F*kJn*;%Dk@6Ezp?Aq4MHyH z!W99%1;aerDGfhOC$fr%y8odlYIJK zWtF8MA)uH)zqC2StA6h42#^59n3Ag(wih&+bWnfx*73gVT~UC4`jxx?U%gMwVgAJ? zU$TB&2fQ+*Q5Ph^CV5+pG-l(`&=JaXz8O3vHvO|<}OMF_1m+pexNH#J%S-+&ydfJcDS-Y@g$Sc2;j_jkCs z^1Q=K>Tg8HQ>OcyR6rWhhM|#^KTGpb9Iy6U4ib&~f)}6YK(Zn zn%0mbP*0IC`26w=d<&=}mQbFlTd}9ca{~<7upEBVoKKAOUZ3Y{!E^!MAH+o~YGrvH zfYa-2|Ncp{Z6&0K!XmAol55WqvCW^^+1l!nYARlT9Q;5-0F!nL9kv0dxW~+*`(mI& z<(d!AF3jjT>dcbUREiPq30QC}e>`VW%j@wrIsJ30?tB#0jA{Jx<3qFG5iQWFegQ1j z2-!?>%dKZ^e)6a?CTx7Pvvddt#P}>REF#02_!PsG@PBNY0hj`X&Uj~$vu8n&wx};a6O895~OZqZ!>7CXg=6p*# z!1tMc{k5wW9RR~Q&fN%kvYQ2ml-$Q<8hJ@O`iQ+Yxh)5K7^k?YzP`TqGA$0SKK1y} z6;BQbt+rdE{j}lVwiQ?wbJ@RrdTs#uTd3D_s*Vt+!wh-gFpp*VrAdc50J{3aY#O+P zHu7}!PbGOux&N+(<0FB*htwYuQqTA545D1Hnm$3aTannnk?ZfIKSsRo42xoaGOL0_ zdmGfTGDZOM+vqRl_=MYO1?tv0T@WhMsCqhmkB9rGU9|_=4{X2;Pu}9zucrXsePGlz z*|BN^#-!rlne98frQWH?_lUO?77ZuN&20Uzzl)23M?D9}JFpCOK>PT#H_7b25URV= zwhL*oJTD&egv}lvz6-&_27-x*dxKoyq|Uvk-1TB>a(D`5_?G{gygt#py5*R3Ow}rZ z$~CMyjX{ePfh1rNZb^=aiMZr_-`m4Sw*XlrT3UP5t$wzkUaHA0UYbtDH#UmQFg^t# zYpBuG$&EXdN6*oDG;!Hu$WM;Lw0lwU&n zE}R#D_)R{ic(h{2hi0KY5%KmHT`o~qke)Yhn_Y$&J9Ed*ckIox7@#55c{krTo zY3u$HP>_QDLsIAEL~qbASgBwX?(`FjnP|w zN$e+y>8&zDv%xopuLghJwryKR4Lv?x(2_TRH5WD%*0?_J=o;js;FTg%D9XJ}h;>Ip zDy@vDk;V4hE2GWBYZaef?0=|R)mJ~qb#K#e3nuJG5A-(w;xK{Y=t3(1;>^0Z0uWOe zL~zSGq~jSo)}_wVK{l=4{6gz*q01Z(^|sk+N5?J&MX=GDpRXToZQ9*FMHI{M^4ACk z4AqkClxjeXF&^ys4k~+%Ehf|?S!@q-A$b4C%JWHpCok3OaNt6J1JVS$>>oTfb8YHV z*>jK4eKK7JQw&^kqcJj#iSNgPQT#2MT;Af4sSUy@NH?xQrJYdz*pj6%ZnTV$g+kSp z*K(W$`-qNRMBX6Tz*~+I4I|8E>ve&I!m+n5H@6%_dVbZXq?QyF?IL_f1bm+tKwa{R zpKIS`H1J(XuLXy}v!oCZdnm>rNPF*9%Y@-BK*?-`2{y?N z`vEsD?N0J)#biZH_{6wlUgC;>7cE-k@v*^U7e;dc4S^6Hes7ej-LhlHeWK`w#Ul02 z@m;8Cm`r77-(}pF4eQv=0tJxLXoS!TD;dYZ?H7U9uG5P_X?jy$I6HYUY8Wc~7ObH7 z%M#BQy}or&y-s{DKwNwot|1mEj~>PzJ*W$9)-AuVNw?1*atDjQnRrucnl~iZ3QK&3 z=*lE}jjf%XkhrVXn6<7jLrNgc$N-fi6RYmkNq$J9Vzw&I6X;!jUx?kCMC= z6wRCjlwatby(T^}F-d-XZtIbFMa^|{_PDzC`GM?^iJb;dw9lWBB%r#s)IrZ_mcbCsdv%O40VIdj)GTZ*U zF4(G@%80h$Av2VZHdi2JXos`FiM1QLkOxla`g|~1{t3uuWMz-T578oj_lqtj-$=Mn z^%*`T%~e(G#q-kp7gQl%lz;JD{K~7_8`O2K{CM$`$=9%ukiRgv6TW-pMf>syMOQy9 zs}qS+8G}dX7-w~zgYMjik;M6ihkKtm8i}~!)WVd@X=Nmr8r2^k`VQ8oY1cyqDp7FY zI%8m&Cp7K`t-FhI@R9vJA@~&{eX5m8HyTs;kWlOWwPCGE1^1}ftr1%nP&bXI9ho+) zTL-$SS1mlZDgX_~jJ(z4E;q>g;!@ZzV-hw_#q!MQ;kV7mC3 zR8pQ@waO~7^$rsjt{$5~6rHA@rnQOj2@OJb8R@PW zpTBx`xbb+KWPs>hiTW#9sKXyG0weX5#!f7L`ccEG@(b!KzQEbM-O{0{hc{XCcogTN zT|scyW^7|5*?p+OrKzw*LCTgJq$<|WUcj&g+lY+I{Rgfu30vEO+IxLfUlK>AxEFUo zM55}<(W;f8DfXgjZ~=$tQ(D8NQ~+)~D=&ssuXO$7 z0-pSSWOVfKI7CymLRWsPPQ~~#-P>o{7chG2US95#X3| z`|`sgb&S{B(z{lB3aPnv(iA1)&=%bXFFK93%WGlIl#kr(jFo6r+X@e!_~=MGXHt*s z_akE)Er51oqQzbijeW$5@Mq_z935^XTGmgOtESQh>Ox*1_7eM~0i8=Q2GlIsLt|$P zc8%Ob_QemC^)z6U_CcJuvJE4eS+pJ&$$QVK8m^S2(+T0vkPWYuy6K_tZlIv^=}xCE zBT^i(i89Uy&h$DNySYPs(3ZVe_SG2!y-5S$=--uM>avf>wGZ*^K1ev}J;Ua#xKpFg z?}YWOJH7LKFYoB7Hg=WsOB`TR;9rpLJL_@520U9#q$ z?%66bInnH#z!6IXw8|gCkL`7g{d;WA;po3_vPTUA78805t;C-n*jz+X_FV>?i^(A8 z*~v|TKg(sY`ciaFwPt+=4;rM0I?HzTQz%{cn=@@!i9mRnN>=`CwCUnhh8wpA(AY_@ z8s5@1Fi7Bq8J%<;=1A!+uSi}M39J%=g&l#--waDLt=ZgeNxyL;IyEtj@(xJ+`zCmo z=HhGhMVJe}V-j6TmKSnK?p&TwK85VDs0GQU%(Pors+nLn94$W@WLVBa4V5jb*KXX{ zj_lm{_tBBp)7vbIy?#*F|7DthrPJTi2ppVv^_+?YEoliD5>&Fp?+`vZ(oB%9%R(xc5rNM6b> z%Z&WM?U+vc5we~|Z|?b0a5qMY&_BgI<@;xj5QWfxhozKDuIxY4T_ zo8HW%tq*AL1|@;!vIm`-VmaXVvu9ZA%z-_7_6(mgfcbP>iZc_Z9Ub0MeDOev*4bWK zO&cMp==GZ46{8^v=bj1Q_50y-$D270H8nLUL8_KMEql1$Xl%OWQ_~^a6PIVHw~3r^ zIwNDTp7)7yS#J-_a`RJp+k3BfV>QFl^2FI86Wv;96#n}Cc0lp_Tb+wl-)%^qzwW~8 za)(iYXNn8G54_&oxwzcH(o%b7NlkUfer8GW;9BMzqDhf3*f_}QL~$0wl5VE+%FDCo ztqzCS`ugQdI=8kh>fT*@_C)e>1#IWD455YC4hB5;x2MV3MOT{^HCHGOw4_YW(l~SJ z#P$YFD7R(_xJG?!;T_Ye=A*z#5bsMjL{mkoFT8w-LX$*zRnNKEUk!3KXTN)Tj!_%$ zG?)XaPRc+lg9%VoMpgy8XBId*c}56zskNZ=;J1T+ZCq^#T6rWoM5p@KFI$?>0!D4g z1_fHP<`MAg8E(9HGRTjsq&%(TYKut=a<^UZxe&HE+I#MWrP1CCz?D4dFs5D%V?Xlc z3)J!Y1Ld$K&F_be75h#rb=Zn0(`lCVBfY0%!P?aI0{pHYo(qJ@j3FOMqE!uMgigLa z{W2Wp#KF5a_gZ9Dkhg`DwPa+MLbnQ+0d+u4;C&abZszUbs^aGB|H$q z`SsQug(MoiS*U30rh8GVfw+gTognQd!CV*tjP*1Sa?QlV3$^y*y#1O~Io}?m?6mvv z#EAXt%`pLyvZZaTnTn1HIiBXeE`u%I#8!gj!v7WNY_Tgd)T^r+t=+WgJP-#yohOr%lLNPW)UGB*+`hd0z!RTJ z{WYe2lh5oe`&9otBmrF5j^~aQW^Hh{0a`zJ=FAz=v_rox$9a)2%U~+>y@ya094dJ3 z&s3{Dm_hZUmn*aW@PqE|?jWmBlTIaMpuU0i+mTdd)c7|Zbvr+QRBmz~$Nf!+*|-{T?0<9@W&@arJ;@({GcnB>f)T;O33d zjfI{oxG2dD=eb`SvQRW$=|C>w@O>qumFm-vIR`rZ-rTH950Gou3^iiwLbd6Ew$!x>~CW_ox!inF*I@`GXOkCVC1djGed*0SF{> z2PbJ&$>W~YRUX3)lDOq#i^Us-YN!Y~+I^eTgmP(-2n$=5ShNvmPJr(IsRqvpue`3y*;>KOX-l)_|eHc5jWtz`}B*S|%&fD&HrNkkD5+#uE@?Q`#_CI3@ghuY+J;KR2g zoiEJ#sy*r9_xTzc>#on&=z-sK4+GxYIM4MIpVJanZ>!s2^7u+>1uOVUIr!Z7eGFnz z(o}D;H)gdH(9!|!PXY$HlQC+f2bZtoP^Rq!`o`fKB~JFQ29f~r)Z1{cjo2<*N1Si|pSK!+Ot zYVh*IhbRy#~7eV#M_`EpVcb*pZ)}JZ29&>*0Kakv8qJYHCH#=VgO(^~S_beh|>`H?t zV5#r9-@JoY^ji|T6zNH;O46x!U!t!2HwE}YT45nvv<+JnYkAIℜ`PiuV23!R_(6SwG~Ybr#C*|tj{5u ziSlmGmNhaB!cP|W81fT~UR^F;xX(squUYP0kjYGMmdgaa@-Wy@ZeA36P5uYxd9B{d^ z`|Tgp1hC3Zq z?cSN}+Pzz@;5@9rxjVx6H6?H=vkBehSWtv)y3c1Di#jKr`+ZgRb7W)5pRYBM6hXEV zf#MzgKDRHQ4?Mrxctii7ri%S=;@jj3yYw<1GAS9hPJbl6Rs zHjQ0qfJVw&wH)4?Q$Tn!xfBx{Su!|WzzN|JYc`@o!(26bORx`ge1>1+8UY2wV3bRnN_%*iD&yli%Gn>&)H4nhox0H5GPFU*CIIe(IO7 zA2UtWLXg8}Z}GJ0ixYo39hGxveKSXi>&A^f=~HRgJ0P<(y1Co4g#I6OJ>6LTse}8U zcADYo1~D=@aW((L-7~{3$R)>cD#a37#aC&xw738E`nGvrqZ2!0DE88zg@QLF!0T0^sB>>JIB?WaT%mjar@qj(*PV= zKY#q_04HO0?@;gkF<^3OC~==|V#RIUeX}lPfXijNRYL>uOmyN6RCb?G$nE3?J2^28ZpYDNVQlX)Nck9{i7eWcScqtmdH>j!fqi9@kFms?VLR#w(>JmPF=389uvY-rypPLb=&a_slSawX+n-0N`~6^kC5I` z2;wOft z{~@QE(66=OefIr1bPt##dv;?kOB&2adx$8gwuQ z!(XPP8Lw}=Bfav*)Bc}>vQWg5;eY)5ijLg0>{4|)62fL|j%e>bfg0VA3vOt?mUUk@ z1P9513$#Tnl(}y3Ek7vsNXj2m#t1}Ax<<4&1Hj3U3ZFZ8uI9Lxx!?QV>$bqYq-TQ$ z4XmWd=Bp!Opo@*^zLgRc<$2tC%9PATT^x>ql7zDXJ%2lXT}xIhZy%7{qHPLtj3KL@ z&7pmCNNK8n2Po9Wn=*)bT4bQ=+}F%a{f-qH;Zm~Mh? zQq+-d2jcn;*|W%CreaKVoYZ@~(%e|R76$^Tnokl$7Jycy6JmlgOMdm9%wRP3g)#4K zy-pj&!()Q5fJ3aavzy6=n&Iw_IXJc_f?5PW>0^Kkgr`N~e)EE2m*Nc9X;TXJ)t#n; z_k@#1MCv4{1ka*1+iGIz&FK57UO*=$5B$i&pgVpY-!uOV>+5u&_oZ6OO9!@Jr=z0CJ;nzh9kiPo_9aLrWBjP93e)&b2Jab6_%j{YL-EiR zuv|cvzI`X7`5oAAo1cG3{4=A1L;Kkj36?rzJm}*&f=;^M_m2(=?UNEzxXh-8Z3~-^ zUZ|XWogWFW8jHOhgeS4?SESLc51zeK{rwwVc~uyQ%&JK*#&!h zu1}7yoR&AXI#-PvMEJ5wt{l*2&$uy@z7Xbc*1f2KZH2bVpzM`E7a+M@AoesuELW>KIGzBJm4Ob>A!GlBzO&!-nl^)_C|eHVDjfSe@`OPV^D4Q}5w^g=c# zzICfshpu<1qX_Kfgj~?dd*6*0y%n(!7Fm?KESpg9yt7(qb)BwNm|`rKf)uPu$`?dT zKPoFvLAP95{mwT z(PUVmNvjQLZ-i?DIq6R3HDe3(8k($be=@b(@n=4eg_7uuNLZDhXEM<%7+(u2)OB|V z0^dKqIK#qC@4vUYAUN*gLHicQTDS=0OGOe+w&EE~qj;Y|gN(sHAF*hGAzP}> zIP7=6hs&Lm<+jiUvfn>@S&@Ws(5J?xH@wnY zP*4i+yCshmU1tMQ@wn3oy0qhJWO1opzuUR)l|rQP0}PTFLxyt7X{O7wz_tJvZCwZ7{=)0bptqm5_m@Ry4nH%d=9$aisgU)LrI}fbnS6u z-0Q6l`}w27VrR&k7JaQj-~aooYi6I}b^qsI6t}dvs}zb&mNH_BR*j#Gw<;#Q{zEFai}5tYlN>{{FU!XeWnXn*OoI&c#`o9w6%{XOGL6F96Bj zPX{EY4EAi*u3eZU`FJ_lf4oraHDIA`(S85fG$2Z4ElwajdY`)$4fz=FWA2vJ=bz3; z#S;6+&!4@veB^QxxJ1zF@{&w_tVgi&yKOw$rt;;E=JD{Mk?xHYXQQNw;=IM*N_YXh z>`01dkhl8zrSQ=>6N8+DNxQ?Yojc#|YPaB)k7gJA-dTKg1X8yz2(Th(8cas8NTZ?b z)e&SSx;5cZ?w)wrn14*X6Er%7tfIGUVTvZ8`oRG*Nt&q)r$J7Z&Pyq?2R@kLaPS;1 zD%|7?P~dqmJg1G z?`JID$#X2GbEh#^Tk_-&<4KI)(w=}5NU{Z14+dhmAPNkClOqMR9Ntsf<6C;zh%Dn^ zN{UlbA#g(o(uhdF1L~Qvu6J`sappFw1%4|pAEGgw30s5me<&2Oq2w_VH3j#U`iGsI zwC2@VpEFgjGe386EITL?dIMu_9qN9$&I~;nN44MN3}di^O%9h%UHJ?EC=%k^GZ>=i z9id$-V3N&z9{P3X)}87o6rb;rcX)!$AGElTNQ8Qa8OT5lX^b0mw>s)LKA*Yv$y5Z( zVM;>_Nf=L2@aF0?RCHY`N$&V)e(8cZBn!&Ex*ON7HNwVhasz!DCnDYEk+;UJ@1Rma z;ps=`0wr!a$6W(Qp7E}&1(`;K6UldmJJtO5oz|I$Qmfqx&dX+553oSOlkovZ4(+?u zgCKoWl!vc5hWqqPGRQ3WVQvd!!g-u+ewiQk!5~p3n|P74Ir2-wX*IINloLjDs9WXerG_G1*zwd9oM zzWWoYx3f6itrTZe>+kE^mDpo;)daY2rddg8T^{ z=tzQld(E+uIf8}7*_OtOGZp7aOXV=L_Q-kS-X7t6Zeoj<^XqLWxLo(|&*2Q2P&&o6 zr;;D`eA76Jt!d?^U?c{4WC9872ExLfsAQN%0YJou!{8_2e>sJ@Uux^VbEYGtwJ*<_ zo#{AQj8np&ET_Q@a6;tPNvN4XDo?Ssd*|nH=Z#6reQhgu-2K5NHQ54q)9XnoF@G{H0lL{M|Z>vAX~RS zC*Y4&#*@i8v@8W&qZl}jnaR-t-Rr`+@Sg{3Juj>>rNlFUu#ul4EC;RXDzp$7EN;bf z?$Ts|`C-01X&6Qk)|`i=bE&w}&{Y)9dr@hs}#{^EgX%LUPkLVw+3!cG(3K$JSGt(=`;BT2gw9DpalmDK z^8mtV_~?<@>MPQLf&JPnI5wtEHP9dyS$?V|wBHCI%nJykR9#L_oPkf9X}76TPoV*% zq!82%eDz4le5$U;pmOd^utwuNv3Rz`gAP|)N!)b~=1l?x`MK_vEkiuWZQ_%gJ76$* zd)C>-R(|vCB6t|?h*L3`dhF^WCA*1LIJ$LvvCG!hqXgNVP z^nDUnvs8PNZ&I(gy1II*rlzJ^5zIr^6X#BZKj~9lwbKb`X!1>{>rlrInHlV5(bfog zM?pAX#KoG5&Q%ml%c&EH9FIYaGVGY9qp)aNdl`)U29JzElt%ug%`-Amrp^DwF0uO9 z#XEOvI9k$4<-ZE2sTdC0lBasHlaB2~`8DR#rahe8uUQktt>n%N>y8XbZr!C_hYnUh zM<%?Oi5nRt#kqTV|5WUiG9N{6W`OzFvD+mEeynV_VAH=e7iBTN10Clmu17TC}aHeC`I>ip3I#a7+fc5oO{Vi7ozLtaX2aMum3aigfI7LDcf&L!S^CY`oXerAWP;JAn zV?W@k+=e{8pc+{5mhDL3L~`4keS48e(|IyPg@jf}^^!X(iUZZeda&U|I$IsoO1Zb< zg5Cp!%tmyRX_wDQ@NZKg9$-+2cbNJC&NXDqi5O}&IVdHSU~(c5klQ~JT0{b~Bro(; z9tr$84P%7inZ=yKESodQp15~|P=6y-1l+{s z*P0ACg-V6>e#ADt7x*TMiX$ny5$DUTHf9lXskr8*Pye}5JDrcyv!j6gk*tgF$HSfX z53t!-{nL(-RNA@c5#YWBADEo6l__5^fFp8s!8tWYv{O+DP<n`K9cIgxF9=dCMbU0BpO+#)bncU1DdS^1*O=My(<%LW z8H^_WIMNOA?%9R7xm7rb*;RgCocZY<3A4JYs;Ysg#%w@lN-hWKP5T|a1uS+{n81M7 zbI5eZ# zrjMgOvG;mXVY7g95n<$j{_>5Z;Qgj^*uJ7wgy|^bC8>6h+!pR-zif6q+-7nS8v`IT zZMZ|IXhez;oN__ED>B_Zwa~m#MNK1G>a2g`M(ry<77pePa!$$t2_Z>`mlHe2q_uaO zc6mt0nz_`mKC0A+gZB(3|J$?y*7i1Ta-Of>y^9b8pAtj+Xqy8g-?X^U<%PsrMTf=V z+B>J@b(25G-X)4o33B2_Y2A_US5UqFxl4Oro_w-4#rpFf{x2)F|GsvRd+k;C=imOn_ifHKm(XV|he#Dw7nx^o61*!{ zt~52H)oWhE9o@3_vrGKTQ6^5)%%*0nTGc{6%en(*Q+wBbDCZDfi0xN2| z@cP4t19H1B#@9~F=O|tT&zrpNIi)IZcX9^%E$~$*`2k&r{VBW@l@FUVZhU|iQ96mV zP$xNlzS^k6<=sI+3u8-VG#vM&pdH*(J+&8}n%)giZrP@h@*_5xPG2L{q~cd`!iwCC zB{{!e?dhUwKsFcmq`mz21O7&TzM$(N0aggdFw+s-&bQM>k}Q$NCSgkr>SaQ)z^Qw5 z$5*mX=c509Zr4L{4nX|MaaG^2;r6r`uB9$isW+oeBfXN=Ds?E1=(WQ2uc$%;FRCy^ zKui0sZKD2Z`90Z6Myr1Qs7LJMARjJx)>C6eY_+RAoML1>J{rHEt6EW!Nfj3}Y5H4A zP*)e1bS`5~)&NvU`jQYpIho)0k_~^Q`=)CWWY~&yzDGh8^d`5P88t^U_XvEvpGpp% zkX#9#sxnZBG&KkBfs9TzLNq7#*8RaWn{>C}RYuD`0eLo1%<}Tz`kNmXKewP{f=h3y z$w>ITq@MSb+oov+Nyjl&@n4^}zIJVn{;4_@&P_4G2;Ev$It>4~iKMI#);Ixs85kqH z-w-CnzG9YsZGp;{nt}yMFp+x?BKhH(8hGSMo`{m-v zPDfHj6kZ3=7=3@llBmZBa7<2`2@B5#fE)S?UU{zaP z*Z!G=&@c7-n)!fM5yg0fyhi&1S#b*r(%9<1uizhNyY6e#2^ppZ&{3PG01x-E@P;h# zF(gX=_m^{nIZ-c%FTS?Zq=^S;m#drGWAwc+Da&WverS4?kU(jOX+NSZq;crD0Pq?bySL? z*!mOm(SY+Qytp$2aJq2u;``)^=b7xPL++9#5>RI*+uMm8i#R)j<}xhAb6bByisQpC zAt@=s7nJ7tmASD!Q!61dB1os?4Ix=Ffy!{iW{CW)L)G^G^>Jn$AZ~af3#eyyeoeQ` zY7`>>_gWRDxJn$o=)$!E+>9^|aecnISOsc)E^NvL>DQ$mRPA8?|3uEB`h^`ogJdxt z{b0h%r?c4mvmG2R$>1f*P7ScG?D{FO!L&Dqf%-P((f#=^~qo0d@h5`Mhkn-nHZR`XXnm7I2s2A28y^#3QdD| z=_mKdkwP)1K79C-)=tRsP*hHa@q>K{!at;+fygqX2o(^DlzJUJVtEH74iIgpse8j= z4m^W7e*zszWh!76fu9dkd)?Q59E)@FzwIpKx0I1_ElsLHDeNZ%7d6&lq->0|4`(SR zd2+{+7QaR-2hHf8x}QcS(tS+J2io9leL2{viL*=U1OrK!Ku!}VM9eO2N`HJAcAore z5_QAiHXVD0i1s40OXbNl|U!_xKKtu_sAQZiZd zBhb!;jE@e2**EYE$WLi*8%Ti!Cv61Ax|GAy)@JqUGl)=X6G7V&&+t)NT90i%8X)mF zsLE>Ghwb1f1Qbzf$-j_SpDR1hEnNv1&X6yDy@HP5$vN?!4> zkcuMj(&jkAw44fqL8a&3zIt_VMVIEyo1?aTPrHJU3c|6FMWU3tot-^2xANv0|Lpc1 z+|MN>%mi3tmW(Dy`z`!4WR-&!H zHe~nhMFWGC&9#(!N^!@&KZy1s=(P5Y{}ZvdO1a51=3*iVX{*fR&;MBE%#b8{#8i2Ma|6 z3nF$wW0WFTC@MBoupkH)ii*JdTQ_L3&pz)yXPooL_m0m!c817P@B6y0wbq<-&9$93 z&tS?Xf_p1?feg5MrzL}5>-?S&+;jO9DO;wD%n2@dn_It*!=>=Pt3cKF?VEMvQ2*I5q!2SFGofjTv$<51iXwpMCY13YVdv!ec zjzHJBW5>(?yzHzO-HY=&!-&zn6a^8}FRTPr-;L&MAJT-WXuC%p-?-)EdRdyq=<}R_ znRE*4JM?TuV&N@}38$|9z<#ut$PdD%I(g?GJfTtxdruu~`Vx|10*wQ#{BBr8SbEa3 zA>oS_p?uuRw-GwP#r_5EAX%Y6}?sJ|AX(z9U*)lV|3; zte{Tdi03FU#}+^($}f@8i(fMUwrX-hIbgOxoCf{-W5Z-Da!Ab|Yy+XfG%7(B0l;q1 z;vI{^S1Mttv^l57lT9MOPm4Otefv;WblAs7OMcB6TSJuXLHsh8fMov~`SMK~GDSiu z50MrC)xn9;KVro*DSGAzO+~Y3sXM|s8V+y@zrdNFJq#ca!dG*|^(L!Q@6Yr7X!6Y3 z3538a^m2*>Ay$@|2px#hF&cE2KW!nBRJ-WER%%s)*L91~Rm-RpmRRZ?4t#AjsP2y%cPb>V`lR04 ze?B=Ykqrm}Yx&b_HA(A{-0Oe-F6umYgS~DtwZ-RuH0pk#wZMusfBr*}81uZfFTK|2 z#pB3~XU}@OaJ~yS2OG=Dn%iKWqS%bK8}RAbHTe~SQ5ZXs8%u#>>B)wH8TMB_Kwtq- zzKQZI)I!OSl_&b8a%hqbZsJy;ATCxtqqvvGm;rGj8zgNYi_71&^g;AL=V)Yd_&1uv zu5yVF+Fe`v@=j(!f7;$*CH`6z4;nGAs~efew~_P*#J z!(m~Ex>;Wyd^CES*TE8S>bw_Q?w8fE3zU2bUXV+rImYn_Qqe_Sj#y1{PLt3t2f4^$ zL8npA@bu9$NyLQda09A}DMceGeIiQAx)Z@h}pr7BQBF!CSAIu zo#n!+`9Ke@X#(_O(xTR+l~&xOVkW9sPDYU-3Qj_|hzA@7J2V4PqLDIeS&t<2Avdu{ z=yR`zem#F2%Jrx|4Pu}&w_RDKOvM0C8J5-mKn z-&jCs6TMRQCpfJor){@@Z(gN1WXGM8ELYUGU{caDhBC#^F_-|cHBP&C{3 z??i8i|fw`zRc5ej$;AfS@d#UXT5bVABhx?>mYK3h{9BN!m$8olro~m z4Na>RWFV}fVD*OZ#Ydx)?s%8Lr8u%q!*a0kJ-`#2pB8h0lu~<)lx3!LfSuPHPKxk% z#krzm8!-;O2XMi_`#B2{Rike^rxGdow#rN25V6D4d6~Y89vTU52HwB8lRW zUA_H9rGd`WA`ssf%ua>)v4YVMcshIe+pn)IB6`{FD_CyW*8t-^b40A@vvw0(9l><7 z2Cs@0FAq?|H?LlqbDsjpNCOy{0P$kY{b>A;fB2>AJl3&`?$nZ?6#?a#Cmjp*dz=OP ztCIxigV;@@nu^B)^!NvQjND)tg%bgwK4vvQ>u6gNee#t)t$6#1T-4O-demwuz9{zg zbGmYArV0QIr4+eknul8I?c?4M1oYQN0+t>y`ss@g;4{FM8rKc28aABevGO!b5SU^h zENsA@k9@&bt{eb6StMkJzD6t0 z)=Q`4>$6M3Nz@#UlL5=#gsAeQ^ANVZ_wo(ISzQh*bZNTN&Yq-n4581QnRCk>N~KxJ zO`3-+A}m6MbC)s-x1;9CGyVwe}(@(u7Qywrm8d4fDBk z=dL=u`QF-Qs5yi$B?tvLW_QZouTN=6qn%oMnigL=nJa!lVKeu^dc&#hXI0AKciR2+ zUe54fSKIpTJ4e*_t@HCx;ecg9V7JScyHdO{mL(uUC6BVgZseyXVYS-Ur>9_kX<-qJ z#T>^Dk4XY|3}%+3iN!qAVs5Bl0fM|myd?7y;7t|5dD5cS%;BdK?!!SlOf3-+3;9U% ztF8E2X7TLTNnU7e&5(!F0CAyp_fna=<>nUTf1x zOKaIV_3)uq{&m~hoP4-7@#4jwJ2m?$oCfol6yH7;1)lealc~Pg6~tmQOVT-9IwB>Z zRuqo-mmWgyF`9SVd}A&cWrZt**^_I*wsi2%I(=(8fk9A}PWx_1-^aFkA}BJlrWM*) zu#>{TxcqaJ*XpZmX;a3ml345rO&55+iP+NM!V^$jq&ts$-TXxr z7G*Cub0@&*<`xznl8&A1xT*Zq`(SdZuzFb4fTrk|FtXbcnxGh#>i~%%tlb1g&(ZWL zy*tSFjWtk&d&oJTBcV2#FqmsbrW>xRICJe3&lTY*J2?Uwv>+rxJMHoOo2pCWmbC`$ z`le<7-DJ%>d}+7S+~3mC5Ue)Ex4a~Oj>n@=zKZ1Q4mMxpWjCE8A`4NNaf>y@kWmtG z@>;(I7U+F25{hWbsgoz|bcLsTxEQD>lPfbOp}NOrs|1)BL)Y(Z6V(gGw*eoKnn-XJ z#?}-RYz-pC&ef;*NMprmtEJU~4Y=)*6A{orJ(&lB=@`V^nl!AMk{0=s#tKg)j56;; z1=BJnAw|kla`>>)g!%yv?~!oCg;UO>gas8!#0<+i zK}BGnK%SEq7f;BL26J)dlU+%oB6GIV;twa$bz&hk83Ge46cesf8=5sWIrnJO$__qu zlvb(p+FW8;L`eHv+%)>hilNeX_%^J+cFp(R(H)H!1SfWy71XFzt8O|D?JSHvzwmc| z$?+M|g0SKZB=FCkXsJj!n4G zV}ph9xg!l*wVLno^Q885>#lY!IS0s-6%9_X{r!+be5CDkmjOlXl0^notjFz>qL>Eu z46Lw9I7H_>Kka^8)NCL6QjzGsmjdt$in=6{iY~a&4KmX1H226qkn%d(Z zpYf75-Y2}q4)jbQjnF!>_g?Yax>~Kv-Z^b!HAPgO?cJ<0JG%#=I_%uZ`i*t{A1#(n zjTTPK?MmL!@9+bT(wN5C&=~vZ@!D$mE27IWG!Jc4vInB^^MfZISdgiXc}W_P&J=Dt zeS2x*u6XNCGNX&v0pMJK;Ad5phUiSb76D*bz$3y-n-!Ox?ndHfjW!gA4j znTf^q=;j+P+YLeUI9Gopl3wsc1hCo_XI9TL_JTv7f3u7h3UCBi(JN_Q6j~3n{7>4x z%PI0OyQklFagT%qN=|93q&u)>c|RUt_j454s)k@1`tHL_Uel%kF$y~LF_(<= zT6YqZnBzc4 z!mfQ?=V#hNgX{s>-zOB@PEYTnUQdW+_+2NY>o&rPvBl(kkt3M~VN#VXxRvNko_sed zTi`!GHAOY_hG>BEl18Z6;PIdu_f@Z1lNQ`C=rExSEg3)>AeSE8L>kGdk%*x6rvYGX zQI!;BXaByG^VMhPo6HHel44U!h;g9X>&GKFjm?l2`*W#S*)mE%x&qZ74uSm$>v^2% zWTNZf3f}}N0d;9hdIZgB0cziS@+v(;x;l50kU-5%;Z}}Ag-OEHiSiHxFcAcZ)S_Sc zz-FVr{rSe+zLOsR&NXJ0N_WDv74K)eJ-;q?ofh_&Om> zd+gm~4*Njua~F-&dX>HSZ1P)$aq_7A?OX zf`6P|@=_Ni1l>4!=oJNdr!ja}XnFF6><(pR=_iIa_o-U-HvNvU_gZbgq+#aB?S=xL zh-^<}HAImY;w-@Fk%?KTd|eWWITy!09Ad{N(x~W-qRl;bkv{j^=hVLtdEq6;9uYd} zr;GyThtf^De0e0yU<(815V17iXzi3|`ES9kpdA9u`LwjB=4!K)%#G%$p)65R#efFe z&6%S}_{j-8o&-#=id(+SB>Z*Cl?_G9FL@6GS6()+4fWA+J4~5+kNv9j0BqTJGB)%e)m$F3AWFpql zC5}gWKPag;$kY8vjSc@3wQAeSd3-gQtUTj<`mcWl?f$2#j2Pn!uQ40?K)MK= zG~e+ID^Hp|QOb|y@KvDx4OV}CM!OD?jam*t^t=HUiP4Q=$1x2wM~}WsA{$v{9u4Et zOdPd|OrYkE&%BudD|0xEbf6oXH>jhwAgdHz0oqSI`JJSq17ZmyB-(V4f~Dq9)*svh zil0a_C7OuA=V z!bdcbFW4W#_SvM6 zisU|eG-JbGFh0A61DzW5rtM86nPRLP zilBV-2WnyJg7g>b(tRpgD1=n`)OjDWX*U)*9Cv>_IFTM-S$f|^bDTF>-a%{7W%w!i z%lR1Szy(r6%l+n}3Fj=ZxD1v(w-upZs^2w7Yp@hEecW{&5Pn{sq2@X@5 ztqECMezsBexy7vLl{vH+s(^klrW!G%qIfg--|UDz@=97}5rRttEbB5&FJe9v1uKu=0;D3gM@0Ih)YLHq^;^6i9g_=duKKCrSn}ciFk=AL0_$izgsQi++yP#IMsRG1 z!mo_3yRK==U@Ow*F@sRkloodbQqt)Wifqh% zGU7m9birp7yg_#7QZB94xKy? z0bWY(5zt5ZjoonH?h(qplE-O1%o|x9Xr}8yFJ7~#VvmaaX~$9Jh3)Fr2$dr9vSPx; z-2!Cs1IcSZYYV zwVW!F3vr-7|8%O5B08ZDpzFn-^GPrbCt6GlycA`lPeP4GZQ>X%M&s|l-^EO8`>BssY`T3M5R7lj&NsF#ca5Y_;XF%jvZU-+cjfyS1jgYwXB87>u< z^C*Skdg$G8T07!A`iOfrYiIv6dX`2 zYpLyonCQ!GA^5}@YW$JuLoLa=(?!lA-NlfZ{=CMVV6VWuUScQ=rFDRgEs;*Z*r9IU zyrHD5Y|Dk|qhIH6tCC9V|EJofolI(ofGkObqa`Srb7q&M_RHS7Hp+}ZdJ}bEE}grC zRy@|GN;xc6LN|IS_vm3S4p09bJ0=GnVtmNNn3?-jTxCgq(&Algg#kSVCgToBjmju}i+%kz0Cy;-9Ud+e+U7KK&}SF!_h5 zi|rLf6Y)&IP6-6kg-^;X)FipGgF>5wuV5LJ#x=>1>>d_Jgw4{xNDU^J4_L&k?3NNX zr?*-BRVjf}`VrvgCZ)v}Eg;uEE-56vhKXp@GUkxXinNC6L9jX!UpAj?y(}SjLP@Px*^125Y;$>`&E#dZ z&ZleltB6|2T_dptMopnzh&RYg^mX7tC6BZo1aj|hRL+Qi~!SW8jF<{ z=dA<1Gzx<5CHsX9C@x8+FPU3mHF4rrTAfUhQ^`e?NzrO4-XreM?W+rWgzhRV)ct!| zKh+YfBZ=pfL$!AiBqXlZ8q9SL0T8JSvct_pnm2;YIUY1wRFtfJ2be(*fFC!)eapla zvH*{686UKW`7H<8V3E>5#f`tkp;}Bfu>I0bX1a5I$YeV*tLV--&l3biQq?`&fsT`l z8+`qBfd9uc%BpmY-r-X#Oyz{4F@ibra2#%7OGyE{CvPPbZ@DG(B$kI!VS{ z+4uEYXYf~O&3*fnfx)GL)3quGI;R(zm<+n#*(T!Qy9;&ONM3RaR7XdR!$T&A+`Iwa zMB86iuXaKXa&vS-Zqgb#X8Ka5sV-fo)KhDSUYaQ#qPrpFasLzC?qhg-`=sSFJ!z=t zR1P?_GAI3L#y3|Aw5|o+d-Jxg`<){h-@KZaG$y`Bx7duU-WR}Ta%nu4elw~;0TRut z*ndKfB7HTRP#ra$`$Tq)K<=w<#q4-a{!GG=8|I@iq3RB3=d$Jm=K=K{9XUW6UKIA& z(9;@oE0H$|4$S$2y@Ga@$s1y|RO!!Dx15<(v`*b+aI_h8qC?)!;y;EMD9LEsIXiha zspQ(e25I;U;>84jAn9>A_a?cE^_5O%47-`UJd*OwtAeC%L%e;$zczWPbuI8-ae3Mu zvwp2Xr$nO6!W2d*Gdew=nu9AMMjN!j?}1vhdociqdn_oIHv|ku zvkoBcl1GnuxyGz@$Fg5$%)FOl)+ww>@a=@@Jvlad-G6Q7tL%)b#w7YokBNrFt5?U0L*!RpDGWPL!}VAHi36(yS&$ zoG4_{0>u%S5Rdn}B7X-Ql;-@SMfS=?Z;Hv%UR53v0}HBHseNNs?t}jaj*3Mg&~-x$ z7%5_s>rvf!GFpbPoIh61GVVU*-O@Grz|juW95gQ2>>l)Qb>*7{bJWs0wQhW%l^6D1 z(q-)9qjewaX<--GXPSy5$2fkTqcl+8DS{fwy>Zx3a>i_=UC2 z@25ntQz=}0aVm&_tkk*e-)*urfSpZ?PVn+0_w~B85mMMP3U4JKcAMAc76&RC4r%?2 z1dz%^KC6ULKmu{PLp@+?5<`yCXpIrLj${O+^Sn zB*_%F2ENkx!_%e)dzEg>ibaDNC9Zk=QfzevbyOp}`xST&tHFBdE7q^!IsbS0oiNun z9~ns@>gzGs3Xz3OV3R1GdMpz_^jokA-Cm8Cd*ZW`z}2d8k7Ip^_gHPYmlF>@AKwfH+o z8wvU}3M$8YCl`xVr{*1?nD|xc}9TNG=#bG!6nTcqI$n=QCNLgDx>&ca(S(10qe~w7_dxHjjCx6jN zd+gY;5Q}5OF8`|{Pa1%r|1u<@*{sO(2Du3l;YFS;LcgO6pGv(>DL?y8kW!YBqtE^} z0UKpgfcG>|I+8p?rv?iC;a+b!FcL_DzbMUFy>lVIG?n+mWQ34pRmkU{JNCD{svN$4 z-$xF5bel0i2s!8S{p&SoK%Y)PZf>r>>gx2Kb4kB%j#ggSufEK2kd&NsqZm(xt#_cd zb=8t4O3FYu+G(t`-}u!53LJ+0U~o1(9hJR_J})(~8o)E^4z9k8n4C{u!$;gATrG-b zY3yT3T+1t#tElC`eecz--Gvqxl^=siK zRs!@92>^BlvtriRL1s5@-aLZ5s|L?B>fVdSXo)dmpaqQ6m}W=}*;Q?bIR?9Q<-@hD z4H_vX&Da29$a7biS+VT0e^z6J!PT7n>FP!!o`L-zBjt)T6hPeU99#X7hw?IY>edyv zM>eN=Z7)i+4=_*%s}b3@TROwCi=r*vLAfSJ1?h&RO2lD0vl&7PBEzIl?;|7i#2Qnv zc2OXLwwt3Y7Ks6!q)d(}Y{&6<$bZ5xwME0Dc%+_vT_^H^yuo?P8Eyps=IVkE>(y<}aHo`yx#aJjO(Ye^I-h zpz~ts06e3uCcxf$_c{kqp$iBa-7ub?UzrhjGPf+=cJQwxm zmwlA^Xg1JK_LCExg+nEb-!vGCn8N9X)KTPnIggkISa>zE7?-w7m)>zWP)iv7b4>Fa@3gO7xDSt zyXS1^C(3YQ;gIlZnAc7N{>$m=hu1clwO+h33I=vr453U5E|BO zPS0kInl_zt@R+=kR`clrD`#;6C=LoBHXeDQTaL}X(nQb@kt5T+F0DuOYh0lty#~p5 zmiTmKde2-}r5f_-%8@co1H2Z-_REXCHO-Q{BLvvFw^s*Bt0(vP2__2%t(5mX#T2r+ zN)H*!pnYWHmfE5*z9G=FoT}_(fu}Ks=*sEk6sCLy<%w-~ObmJ0tVxr=#|F@nk~E9` z2gp^p^o#E@n}6+30)$j#L_J@Q6^FtLGy1kDK`7MGfpqS~@K1MZ5Dhj7=+Ze)+0lTGwhUEyxhHKF^U(B6$uHao3U>8S<$uyq)xJoa`!} z^`Q=06$5n|6u_J%O7DcbhfPRz8?TWw4P+LEO4d5aD{9iB_4Hms2fIByUu(puQCE_Z z6mvZ?W$DsMt+{CDDZ9GzvvLz-fYLzQ#ZSySMR|F@Y1Uo(SF#E{=w@)m60iA&k(Osj z^rS5Fu3Ro6znj%^%s@;)A};&mJ;U)3Q^Vb~glC!Vao`3; z(A>WGkaS#g1#>hEJq~~q&czn^=2&rI(efKxO;4djLq(hZv1Ln&?YsT%E{F$**or)cq_CaJ06y*Qc>ZAh%V zlwbp3g>|`%u+X98VBwKNq`Y!N+9XA*B$R@!VIq)m&T}aqEBcoPvq^!dG8fqaa%OQVw;`T}NUa&jEQS6ezh$a1aFs=6Y(-8I;262mQo>hBBg_Bp(;SfhGh z#S}<}Wsr|EhI~9wY+{C+d^nPPv11d804NQ&7-Oka5RhnSuD1CI>?>U+0`r4mDRyl8 z39yrdCV_&?oUEj%E$0wkFdY?_6gsUV>3b97VTcjw{}Pl-sxkt!h?8qRnEPem(B~u; zSoEKU)%1^WBD#NS45aB#dIkymP5`=anHI0-;f&*d@G_ zRb%8!zBx`KYnXfa92K6_0xZ7}j@x{Zf#P(-Z|Bq;YuC1ttl7VJNCpftQ4lx$W47NR zuj|6Lm8&8ytg)hQk>rxeK_W}A^pq(M6qZnsFE|3PSd5jM<@W<%F-M)fejbgA8NwtY zvaq}hU5AsleNBi+Ji>@9gjSIV7}Qr|9nwcCGPf9aH5j7X>BsUl9y-hpFl_0ye0kr+ zkB7s=w7awwf={gRv06G&9n-WzCIc(5AekYmmv!e(M}y##OLGXNrt{YX4>@&smhtNd z2U7f_FU5=p0=`k>#sP=&Ck|USY8D6i`L)jJ*V;UEaL&FzpnT$&iLHQ-t{@`F%9`!H z?7FU5{}vk@%=B5d^yx(TDQTuwbC>ec`D#r}jLXqyyU|=pdQFEDz2P z-Aii}+Tmdqb8W-9GXie5FFNft>9atrx}na(t@N9P9qR;XpUQjo?W?}0mQ?;&^Ra-j zQe(R6d*r2sClGeKhSYL95&!;EP>~owtGW{7D4BlAcH=@~~$MY!#u}^dO z^d{F~SD>n7vaku-2VDllNgR>34H29edK?pLdPf;v-eMhi`K!N3i}YjRbc~gR2H}*6 zMnvnVy&ah7M$cNmsTVQf0xz0?QN8r?gBf-~uPs?ANthhXF}?na8Xvn7`#hzOq-{Tv zr)qD7Tz7YO(Pl(ko)W*p@XF*xI8O}|w=}2vboozitrz{jFp=CT>h(}hxd}e^R#xj+ z+9s^Wo{e>qApuKLO#^RQpt{+{O(9abaZ#vQX%sm%f%(hhm(&OQ?9Is~2+*H(;H ztufBK&f=B4J)no}gYWZ{H+PPPet3=xp&uE4dH1*T4Y2WbxQPi^o5L|&`X!T z`e!hcnWXk`M}gH>FxNu=W-~jczBn=pKwK7udLz0xvB)Kh7k%%i+T;&nTQS`y`J(lt zexMiqfvOi|_zj|lbK8OPfsH&QhP7^ei0B;BBAxf*iroKPLy(C?kO~})Fis~P9biQxaO<4dr5fPr9)qX?3xT;{ z?qb5t&XRxvjX6!SIKU&90V9BnFfW5aC#FL zKDj7cQ2!Zhs$TV9TSpd(mXMP=ljgA~q1&jl!d$?9G(<;uhXn^+9yiQ26A2erLOJ<{OYZzoWevh4;_KA6Bu^cH&%I`-xt(VAZG^&tXF86ytI|u z@K41HT}e($&Z3q?mor!?UL?OH3#J?Frg4_CD0?TfD1|ymO47;P3Kxnf)N!(LbvRAT z@zih8ppNKvNIXXC3nC7~EZ(DT8s$USZ5^+>95nqFEz=eovM`2}ZkN!-|H|%X4KG#y z&m2w?m_q;GcR0zrBPM0gQ+2o>>5{Z?SW>P^_j*R@y9B}N=(D9L63pIUmxyHrsso1v zTS#Xn-d?05`_Z$guTqyjlSRxdj!mJ91q$G37S#c;k=Vwrjb<7Cm!y5a{62I?BEZ@Y zM?`=3Nx?EHXB=Tk1WARi zZiY!^;++V^PFhgV)yqryD*YD8NSL_k7rvmm<4TmWEnm4xbtc%Na0ke^ZF%k9hL8Fo z7!h+yUYt5BR#MTf8-&B6*vzL5izZU@1JStKNYW{F8`+~1bUs~DHvzeX<568H`FEAk zNX}M%MD}ebx){fcLe4Run>*!5n~(lVA_}dP|Z{ zDe(!jRQ9%kmxT^H{h+Z^-xjo!wyj&YE=b{nf_nlwVhZ5nqZL#7zbd%;#0JCHL(~hE zN<}c{B|u>&%ulMb@^P!yiX|P2nlLnE4vqu8E6)lHA&Nq@E9A+R95mn9D?qTK>(nr1 zZiZyEA2?EFu9~E}LP|@0JTI#^5bNzp5sc-w6dy2@Ji3q~F)jpgWEc@xusfBGy+d7> zeyY=dtmp!BcIY{El!q($`@0@86~h(SD8}HCq;NtclAD5#UniR@raAlRXB8p5Mgsj$ zR#@cSY<)9@Fs;%_3R<^@wcQB4Iq9-=!=b&Z8QjnjeA_OkP5gpMv+VjlxOvkWVlkbC z40zaS;F*m0`Qo;4yh6HIP%Ob+S-f1>?MbT*ie~?c!e6I)EiKjiQmV4D*&a{Z*dT=M z6*GPskxo4B#Uq4d#FC&OwmjVUam1G+bGs{-Jtg{1L*G+LEH`a6(`J%_mI@z#W$yAx zUcFlc-|hq$l@uqfhgz@to}s*F8m7a52f%?0kQ6vXZ9zktC=97)I)+;)2TsU^_6x_) zN4XgG0Zu_on0CZ0D=qkEa~2ENQeVZHW%j8j6qCpPxtJVsACuIl$qB(P8Zh67bXX=X zItpM!yMEUo#T~4zOV{3>ruR0!QVXCxD{Xe)hIhZNliz6f_rdL#fBfoV0_E6_OSAh{ zCJi==o9<`5+;|paawrMHTqi4EC)2Kwu=I2=WEHS9A@VAv9o!hRA7Cy~ua4n=)O_%S1XkLRC6#Vff9 zd>kf$c!T4#;{KH@SFj@j($c8tZwJ+Y7N52D^_pG|mx?Hj zIN9xAD;Lc6l~~cUt{9$lfSnlvH*w|KwQH^`*wB)D;M8$dyLyWT0KIb>iKA#oy5j7v z4TJ%axI(@dwaFLWL(EfFKKB|QDkW0GxF4GtdruZ>1DmjlJ z1*MJJiN1J0Rja|vZ)XIa47a0k$aKURiVI0PBxMyz6C#wk0A3<75G5CM!S<p8M5!A8Z&OZGMqxL&<_Y1HikZ;A52#I}um zN99l3GbHR2_xxC!pk@pfVa~SPypwL@)7FQkM8K8VJM1JZd^WnXzp}eyla?=Ez9ci! z2P0Gp53VH6G}7`b6V3rcM>|*iPp2#>1+oeUt^<66@A=#I8n#&SYSn@%|U z*K)dnt!;LLmwFaR9k?XA_m!2cyra?Bqsgs(0UU10{0xK$qUNPnSDKy$tmVD2{Z7oz zwxtmHgd@+nPQI5x?}Oy55?g3bGskFf*{tQO<(>Tc2q&i! zDvtE{y4jG55a|X3-$0sS%@WEaz}Rh3DHTWCfg~0CDMMw30h~Zb@JZ2ViFt}7vE29Y zWgFxg!TPH2pu^ImA6ocDKpZpY2dSyuvL4UmtM>jZk3IOXbH$l0!ni~kC{@D4=8%d) zwI|XCMTidjF`HisAtfChFEGayCgt2!g3#vq(mbwGBPO>3;@iA)>C&ZJup+Qtpd``G zdZh45ZD}+%ODrrgj6#FF)o}cCyq%<>hJrhdHQvXoqS}WNyOgLfjpi-)oQP4 znN-#_xWW0D?jTVbl%V;lXCz!5)MA+vcBlfL<~iAkRLpWT&P$=-FJ^La7d$3KgOVzp+5~UJE#Q`f%|a5iWarL5n$=i67BDGLM-7lZO5=iYsy{v5Ivw|YAH1M| z17mG3WZM~3YpHjZdwGdtlxRFGbEGan77Y@CF_Pw31x31SL(ql&uy&C!AmV_JLD?}z z>52Wd2ZOh|R#%-5n=cVmwyOJ{?JeQnEKNt=LuY^}Sm|h^UbdShZD_Zqd5!@Dw!r?z zEjRHy+$U^=2d0f8gyJn38@JSG5rF)YIca}y6P>@BlMWhFcqTkyJpv?nC#My zY589F`K^O%E?Zw}+-RwbPO4YC{4S?{KG`<0qWt2zK0#N9KKgrIzh^n8uevQBT=pz( zNYM}Z%a@n)*Ug=pZqfAkHV`QUQ;Z&0ki*9^Zc&>V4#WoQ3OuJ|dE9&Z!yaV|`gf8N ze94mIGK~R@e2167y068b243wjBgC$C*3%oW;4HLzjQT}M1VY2C>iNd&GpXe5;P#S!#GFB12h#` zgeWSqdB>X$dOr~|f956g?PI5%D+|=sYTf-w-rKtSM_=7FwnzRywECMb+up+@CvT+M zSvMfEGMVZdy!D)CDZ|5p=!JRy+al#F1&qG%V8iiJhm^!b`C_RRt%{xiC|!qTk&%n2 zO5<+#5+fK5jTNyogOzXeWWz9pRAs%V87fj@gxPnCm#X zSm(gX@E>fKJ!yPqrqHx1ot{0fA9^d_c=w>iqc6N>X3BMr2IMI{Sv6N5)8z6@87V`| zmMGKVcjf9>wdc3iNU>!Am6#FS?bD}^9@muaWCz@Qo|e*z{{pGbuiYpiGFa|R=L=T1#5AOkEu7I(zv4w}nX=P1>JVKe}zjf<=F@K=} z{)LC_^4GsI5$G_M2IB}0-Cgg|1f#gbP;3bnj^~WDwaKVb(~g@5e|$9zC~0@2LM8WV zd8J+LQGh)!$#|s9!&T7@Yo@m=Hn6B}wW)knx1!{*_jvYW%kLEWd$;B*ppHqCwj*1M zMNdXuCsS3a?-dY3TdL5wZ=n8?z_R746C4k{W24inL3YzGAsJ{%L@Pp*c{Vs;zezxv(ezsA3@_C<6gMJv$tf<6s~ zX^Im&Y@7?0tqGwt8qi+p7OH-V_8;wPCjdt52NHN%Mh+-E<@<9$UShSvP}k(|9{u^^ zo!fDVY0$ESripQ9eyGE`!SWn(?*H)NL9!_AWeRc;&L1@CMDbmXHL}16*j6Bfi~i+P z$J2QyCZlf~uBT^Ap!iIRcBw8=+*^zsxvKKxvmpcj<qa%{K-ZRp-Wds7GYvcZUU2A<~W8p@$hvf-Gh9dK%h3Liz?z;4sQ5%UqPFx@0 z-CI@fpif`LZ>0>XlpQMOrf5tusHnS>SQs*KNKwk5on_ZdKq<5MUoJZ4kQRKMCyeua zdy+4N+0a3%Y=eKf%p~tkBW%ecjSUoePxZCYPGbE)KXuX3QD2I7fd$taLCl6fu zFYiS)WkoG59a+>``CNR91q%*|`GEK#GIM~mHUKw?Plvzq7XJTeI;sKLZDX}&Ml%^T zl)3m~Pgepu?B1kfJ4`#$2W8NHc=gAT|1McH`a(1rVef6GzRi@>Nco-5mj)pZw`w$x zf1B~=WBYIAbXBTu<=y^D?W~ zw^@_VFJ`3UGWx<|_BecA;leFqO^UZ-?QH{oQ zSfxO($}b-pv2Ek3D$8Y~K5)QM1O4L)jBAWLy@_1n9g% zxF0qpeYTm>a(=GEkAJ?P`4=rM7MZHmy{f1r50cE@E|}AM+ke<%ek*6KSP|D@Rjtn- z@Zt1tvXn6()&Fw-;+kqor~1b;2A~8SLyaSXlkC!}%Dt-&`|nepY&DW=Zvg@$G4S5K zDN@702U<~wt#w_&XGIb*<@sx%99nNIHwwUH0Fzng$h*g5@YM(dBtaIMr&b(_WDH#+k z#n-$mEs@|xAGq~X z7!{aCY_84FjCqu1%^6SDX3<7@z2r#eaf|#z*!THVF0)%H>Bz%&F2G~zrEvuxoZvSE z$v9;;?jwinG4~>fEm{jiP8X`bfc$9w@irj-H2vNn-C5__YSiD~=qC^Oo7Oj0L4&=0 z*5;0EYjWVr05zH5+c$4=A6vY6GITY{*&$rMNz6~DV7M|sS>ES|b?no}GHq9=>HG-D zqf^(;(a3GvX8Eh1@XtIzRxf7>V}RC4D~RZ{`PI+}Wrm&BlU!|aa z0})Zfv>qD$Cuwa5+f6O4L@&1~$!^!8W|w!(wkBf|z=pe06TpnLEnfXHV}@Bn``NRT z0D%hgt0Xrd*AiVjR#l^^WL*+Z%AGZz)<5LoNouDZFCnvl_0uiQ3`0k?UiQa#d*t_8 zy4zM7HWCZK@F0)m#}f&aj8Ns6ePc8$qb4`jTkTt&8~?;91%Kk(Ou3(0UHIbCrzyrD zq*Uu>B_$<29b76RBF2JrQ?q0j-Tw3`0^F3h)AA34E9wGmAP^Pcddkaf+qQ{DkwRAn zQ@NC*SCU~ck)h`C3?GaX?b2k*@|!{nwr;&3+QV#d^~PG8%&JvR7ISMMtlG3WOv|L0 z5%7m@LTd{POdZ}7+bea zok?d|fD?Fr0C^gizKrxh*9m2C76?#gG>8>YlV>3BBXE5XMTtxs5UBtp_N}57Wdp_H zhj0}0Je~wja316pf8IdR$}lvDH10si!avb9_qG9X51-R=#1!Q6QWH~CD`)4UvQ2;y z6771cHG}c;6J}eiCZO_DDrAi#*%DzrM2;k2DhEJc@Xbt_lM2URU~qp~VqXq4s?tqN z7n8x3DBfr;#S;iGBIe~;W9>W6$jB(E^3RCoQ`bCVi&ZyYg`}a9vhS{MoJOiGe$spz zX*_>gDzy@?OaSTA6yJNWV$q(`(H3&-SYMZN1dFaHf<)0dWuAZGpq~sNxP)#3ca&FtGBniILa#%IVL-FT=OX{yz}F5#)G%6xNbdRgaXFmZ z-LoTymJ^IJ_6)20_8LcAKzw|aVR_Te@3L+-t2xg{cz^z+y9e#hsGUI7A#=2(#uU9N z+t`uj2fR3=5l2R$beTrAB>AvvUFf(XR*IgwcUhNrr|4L@^ujy7S7D)d4?79Dx>=KYseefxcpb+P$=)8zE!+qP^0`7qw4uiT%tyDjGU zf+064LdhBd)u+}lOEHQ&CB5sjzf;uR%11VDXZ2#`tQlx14-jyiNHxWh=zYY+--CLl zBqwX^-n$Yy&7#Y*7d99&B)R}WPh@yP0CE&oxZApR=C=@<>8f>oV5h5uM)ek`_XCefFpH(=AIak-UC z+nkwPeH7QfKkUz(xpS9BE$54wfM5;g;wLX2S*;A5hv(%5nl)>-Y7a+GGJ~Hu+VoGX z2CdL=IL~sc$j!A|H@LL;&mtxHXXD9>nAUUT5Z2%^6{!C=sbuxts|4*a*3fF+;bPB} zmmH7jBK&|n7De1$5V@`-r((Z4b6W;rg^-sXkJog{i1YGv@r3S{RSy~J$f^CbG$q9p zYv%EY=BRY9|D642v1z_i_vNaVasaoe@v9-YFgZhB7r z_c7va!!4gzuhFt`4)uL$7NjH6Pf{IPc!Z3W_F9pLqYsOugl zU=A&)p`~X&e z&j5-5yWc^wD25nBFQFqx47k;^fp9j2ZcEm`jD?1A+t0F+G(|Y_w}Un(Sv|-6G3;+e z-^evd_BE-&c#ItGLOY4_ov=$%gju%Kx6)3`EH}^yw`*@*oSidh2x7X9h+L(~n$;!o zlXc<%32m%aiv&69>eqk&`w^Qdv>DC*>b_!O<=?FaIA1+!+O#078`=Y6BQ{K$@_a|n z87q;Ki$J1g-t({1uokruWf$1tlV{K7EV!1MI-Y`ECNfE>mSJ1@rfHFDXj{xM6A z7@xI1H8ZcijrI3$q_*HfAM_=(kM=k+a}Uo`A;eg+)(d2=BHy+ujo>KRH)q#aGS?VS zMJ3&NVa+mY?MOngSw#q%-0E#vnfQgKdKE+?E>Rz^zg=^Tt~D$ZGWNd2z|eB$%)M~! zQ$u>eRXD{|urH+Bi5BOH>whzuP!lzFzbPhasQP_T-!{3``HSI$Qn!%t{JO&C&Cr^$ zbPx9gZri4lclG-N$FY1*y%#Qa1EFB3d22caUB|FwrrUDH=*lEZ@%)L3>M9Lw3JQgG zM2s_#rX5#iGHZ10(vZ77*8l~W*a@W!R}xi)AY(OFFT z#gd_D;LT5^1EBcV0}FS%?mX+LTaAK$_*hF<;}V{JrMfP9dL_;M^ti`gjgl34e%PUh zNtb^Qno01|EzV2$y?|*E1RWWn&g9gRyV;e;6Q2#B*rX*^69S0RsOHmd{CwGKd-l6? zC%hMVfxt6mOS`6!E6R9I7E>1k0~;2e&Jxp4;q2h;XR>aLmJO>``@S-#pB@pBP0ZyN z6G_fS1efb{)|Xi{fU7sb9_LRmRmfa&^E+;XXJ%(vzL z9q{{nt$3xn?mYF=SN?ak#hbWifMa>|VrW9ZHl^dk+6y4KYdv2Olnl|`unR;)2z0BV zs}6XgFQ0`DD*xLR@8m0ZuYFP$PppU82Wnq)b7)A-X$o{+=fgO__842_=1sVQqXDD*#}6g4=;s1e{mc zDWj~t4=4JDgoKRNP!~G7bw%Q&@!!7vMWUa)$?o5~teKJzfi6_0H0hS8LXzw(07YO& z^_VquW_S(i;R$}FwVTE)E4G_wKqA5HX64X{lkmPVvFLgQZrh` zI*v!u(lla-@_nod9!>PMv$r?aT(TIrl$2FqIqb`C+Lseyo?stC>3zG6-B+hSUdfay z>gQOSHWR4#8e5FLoTZdir?za4KWvgl=XdUtC*yC<7<5FF7`xVa!oZowUg`PaQXiI>$1`}QD>hl`~K9l*?*j^t#xTs&;)~CbM_AE zd~nd=k9Mk_-j>IYEIbJ$*Eny;N@~`Cxl@ePy6^1ut9E^%W#0d($!@!V;fKciJKei~ z_~En1#qmuCe`_ApYhf)d^S&wD+_y%r=CU3(9MrpaZ~tApraoWMPOg>KUE}t*kT(#s zE%$hKzjLQkZ(0I~$3VQDEtGPbEVF-L2MPmC2(fr#E0eKfk3Uz6`_9|BD1!_Qg*5^6 z?rdmy|MAut%_d>y3-}|#j_+G3c&+XivLXI$^=6(`TQRw8Y`)m#U=1N{ePvG^>Htj$ z0<6TI4P43rT4WTs*!3ZBe_BfUk1Tuy&GgYH#m-ntIK$U}G?eO4ovi>%{abqS>0ttC zS&%Vav_qr6P668npz-zc21ZU`4w90V!^KtI zDcmf$d%pSFCp4XKdZC1ko_lzI*_T@TcT(bZLAb>*zT%$_Sz5fcB_-gZuX7tKGboVy z4!M3(%SSU|oJ$91L#^`{eWrZd`0LuE&ey7^*<~9qiwO`_Bxt|f7q^HYjAC`Rh zX`E=_QB!2#Jx0BreJ{LoC7Azy@F}|ZSE0^@jPVjZIDXqY5w>J2mzI=$%!8)~bl*`S zu1|$!`MKdT`W-$AgUVDF(i~kruQ+EChMtysrpELAK*0NyRg^4$nVYMU?h!l@fK_`p zxH!}FV$Whsvcx`Uueed zXu&AFeO~KftxfrC*X6k$BZeJXz4mfw)R$Vio5gS#VXDS;&KCpFPD`Nx&M2PFtRZw| z#|JE=R1msjyi5@jYyEdai4y@FubiOrS)d`Z`?TgWqVmkMt@^M9xI{)yQa$EZRTfoo zOdlvNb|s<})hasucpIR*$duT?fo=@XRkMA>ZKjx(zx^O6)(KV5yemg>4#bZt6N5|2 zSm~NCE{M7If}i`fIRjOX(uZlSA=I^a6yAS)4+iVGnfHgysretu%O}S+PPs)5a})GO z?8OA4U~0&OH)KbnCFRI@>0e)EjkZb-g4l^A}L^ zvq$%)n2z$L8zilcLKlI9FjxJ-r?4nA{&i$CYCI~Wv7CmPIdbAyR|o)BdHiWyw2nX)ELsfB+7Ns_0Sm@yX-zQe5fg;VIQ4u z#;lEoVYC1|c$)J8E59^xQi2aj2AmQrjUg*Odd1hCa{lJ&r)(rSaMrC$uc;te*jG^G zQgGl-Nii8a0JoTlohE9|mI< z7wEm-{F}CnU+9rm!W5|}8Kw+MF0E#%gQ*P#yp~f6m4pb9G68F5_ueM%DeL02^{Og7 zY&VK_gx*wT>E@=)JH9y17 zeg9<~G{84MECus8z5M-AY0j6%dxVVF(UyvhEg{Sr5~FLW-l+A>_OP|~+s3bulSItf zWP~;lVmkb#W8qTz%0$xyu4vBTD8vrP&kK1ihlzRk7=|T=jp^PtUd5Y)mZtCZh4}YD%!r^}`YL7U1c#<5!OM!~1G2~qA#4|~KasVF zwOwWb#G2YXZwV4+q;mr5P)uqPcBn{$&khj9XfPL!WUWrNz(QPxFS}L zoI?Jmys+WJmGv}q_phEoL$FB;4GDQJMz9>9(&s7CIP}z)Do0K5Y}ECKhAmrmWt6Fe zUK<;mCP_K05beCHnm!|<`R07$NnOt@Jt2*CBE)7-3)3zmIT1FA=UsG8kE+_|k=twLyg#{6d zv5!C)n9r8D$$;Q+(7voe3+b)4F89EYB%h)rW?h}>zozPd(1qA}(;tA>2 z+N*e+7v?+{NYdsBc<;+R1soF8*wzTLQWR zmqYxg-P4DMmHsCnj=l>Y??iJ*ECTeAarYh^!l@PMX=$zH{vCM2R3?u9-nCq=dmAU;^p6|K`p6#Emuned@w`hR?uOo-Qk2fQ&~J zFoi$7C2^4;Ir4y4PP-IzD5%w z4EBrLPyiyu7e-f19$CQJYHm=yufEbrD=D(hcz-RBHVlzSA-Ito=)>`s_hle@6%sLV zg{Eh~wxEBl-BEtrZ3+mj>EAjVI^}xoKa#v;#Dv|>ohNQNdt?HK#s==kq+Jt^j6Hd9 z0yYhW22U$WOPV3;YnQZW$r61eq_b$fNk54iS(c%6*b|Uw;uvu{ex;Y882Y)*~^hJ{|Nd?)Bv#cpd@OI)PE%^Xd7y;{GBcJZ1i`FY{+O z1#95^w%-b@id{L2ET-=XFSHmu9-A)xSzcg>UqdlDqd3;`549~?H~vV`MfC1pU+lRa zMZpi66qmU(%02EKgv(i_+Q%lm3};*a<x|*&%R!jSuz1b1De?iVIJYT(&t5Ri z^FNUW$kdIu<)}^^x!)eBFJ<(BjQ=1gAIn@uCm@`p_nOO^;NTu=n3Qx>&m|1UFxE)A zE%b}yjlI7T3_=n9HGaRH!}5qDgL)659*8Q zN~HTnLIKd49D1bp(s=!NCzs1aLwPoT@ZKVPglG(E`hLZ{=#tdED;K_>cFOsZZXmbi z_Q%Us5PAqOWZ=i!etWdE%1ZH7&ErtJfsS2rY$uGwWqYkVYkD6-I3>Fg&DKXmL&ml` zcD<7O-2LX+UQRc=s+R_4h?+%~tqc{UN#f}9eyn;?s*97182)p9@!8CzrH5rr;jcoH zofDVCAog*Z5@$UHIaeQlW6b<(yflx?4PC-#V3|Xrg1W{q^sNm z;6wjC9L|zo%HbjqC|hlE*B+MJ>*@Dk-NqW!>UJHMe3>o-9wLIV)TC5-?}@Q85{Tmv zrnm12LCUWUdb?gF14HTJwC0!~po+7O=yWBYuasOx$^v{l&4s~02P--T8h;%iG=+r= zsp@{J4+FOMR<&Q7baP~C?feTAm3Qx1|Kkb~rf}IgUOZsBtJbSU2rL7o&`*$V}!28KN-Hj;v%?~l~KAL&nmQM>yS0?T5%mD4Ck;f$yKneXcWWwpq zs;9ZRb5R3)paBF2;;vF?air;YS}2xu6{UB@oluNbWSF`zi^!Q|8mdeD%CCAOV_r3* z(Ff;WR!$^qxZ-*F`zB4E`nhX8obfuj%i+%aS{jI7y*jeBZsBaq7>v!##AUQ3d`N{F z>OPZPKa!JTi}N3WB<h!nPmR` zrP-Mhq3+DsAvz3=YvspfI})x159tBrM(XgOY!Z3fX4g6D6ShvADyq9KF$a&QJuN8c znpc_aXGqI7om-hS8jSj-wpOU`l%j8>a<1^puV`AGI~W=Cq>)Z-0q@XC-<~a;5$iNT z(g||TCJAEhMEn&G)k~8+9=54^w_4P$Z_6pJ#HmOVLKJiHq$Ux2e}o&`?-Cpia<}d} zT3k&^y1|7wcv%!6hzg#(c%l1$+B^H7s0%obpX=c^Mkcu>oTWXWSU$jz9%Yc4s8NBQ zJ1ij*j>XlWsVlIh0-7v11|ycKLuIaMbU|vgg4sD)N9v}NhakGs0^=dJ(m273hwJq| zsvr9c+Wma@eYo%U^Lc;XpZDkei8!`wZWhsLANqc&p@#$_p6(|^TE7o)#Rb3H?_V&og+kx>uDiuzsL?3xuqVVZx4D} z7#Ugr#mar*>gcsrf9?Nm-az;u(&Fzmaiz%V@l1Tv`Dp%DZFjb`AWeLLYMvb;KlU7~ z2f-a+E6i?pQCiT*(37aQYIYD_J|^4N0#X;9y6vZ$<#ns;jD`iKt*8AD2~*FRLSpU= zvrlJ6dgwm&TDhE!GF^6KPE6B<|>!xbXxv{vCC-m1m=UBDO?qfgvP+I4>t9F_pkHPGh zDd(1~zuK9;D~;UPfh)DFB1wVG^h=?OeO(7XMwzySrcs z_l9dO!UdKrBf*-`rKqS6!7UQk+Ktu2gU1e=G_VIJlY#oNZGc6nThL_dIQ<9D0kIlP^`8%Mm~qg63!fd^?MCBlT|M#D9&srD?5*Cp{%i-P zTWg%@hU1Q)xO5gWmb71mz+rVz{;Ja&74?;wIhNtA)#C85BDrH z7W`3|TF&QgLrHn+?~~u%Ob*!}+`PQ$b+Oh9)=7xTIFuF|qtu$^Lv10)kkW`3&bN{^ za7j24x8u)j!gkjEflg)BBcej4x|JJ}VaYoGo>2075n;9DimCz%-Gr41@MVy#eLCog z;)gk4a-qr?%ttQM%{i`~&-9ih`VonOsijX&5A?jw-@!p*re-WRME{brKcL7svzwa! ppWbSKw>^{lf7|rGq(meBX_n_V&S*$%2-W>syXwVf>sGv4^cSW7CfNW0 literal 0 HcmV?d00001 diff --git a/docs/docs/assets/images/llama3/lingua-8b-mxfp8-1node.png b/docs/docs/assets/images/llama3/lingua-8b-mxfp8-1node.png new file mode 100644 index 0000000000000000000000000000000000000000..0e0bb9383c292dd9d4c153ff02cb6aa0d3de9ab4 GIT binary patch literal 228791 zcmeFZ1y`1B*EM?Cv_&ZvC`w64BcX_h5+bdXh=>BxC?KH-sH7mJT!Ns4bO_QZjidr1 zDJ7x61qM>zJfHX5d%y2D#{L0&jQxypKesn=d7j6yjsLf>A$PRcz#)6ZnmOg2j{~L8*-xrAFz^kN1eSS^gwbVtvJ=+(u3ae&c`T+ ziu>{3H~Q^lQr{^YwVTW9ap=_A@XMw4P8mUAh5iYL7V@V4&!0cPElNo(zVLz7^pN)s zc`Ek*{(qu38SifVpI^dH$O%~39{KM-M*LQC;NE@I|KpeORg2eK&k^7J@1Og)H@Pz8 zzkh-Fe!STK|B?UCn*6`7rBb%VK#8ORN63u$)r$My)ON9NNz{LDo#XJeqJl(CO}%CB z`plPH$?o8nFAX2muBS$e-F{&;c+0|qLrSWs;F&^2xqyhxKuN)~HzFrbo>c2KZ;dhR zJOA%(`x-V>L+g)MTbp`|s{+T)ZzChLW`^D?d^E+f%elSzx7h^Bx%jrMHuK5I>>q22 zjI5a$9o<~3S9v62hx{q}An&Jcdj+qOw441dIXg>?x4fjic=6(=w9C5_lp}rf^TiyU zoGvA*>=Ci-q7%N+d`d+{rK~&r`SU0E`}i*BY#taGNLG*AINQUb_;5a)S69$=@utYl zHd10@Vp~_2Z%`1!g-CvSU0vP3OH0FFa?NQXBR!Qu7LGCdo7}wlxG|jf^_fQudSz~s zvhM%14ox!9(pHBYK6iaRc%S61@$qr8pgdFf?XUSv_5aQf9K9GrXJ}|hI(6rc*iqFN zT1Tb>{rsx_{hAhXp1b<#t>(6C*-?9gl#}sj%`r0izxA`PowYmmWMw1PL~`OY);A<9 zOslUUbSKZ(^PYS6?rrPtew3Fd`a;|$=((`@x8dqXG$Xr;TE~9vE^3L9^9<@QzH9pJ z!;_m`X?qrb{aP?wquai{?tA;|OT~9?M?HUjT3?^dtTF6#f02{W&9*}`J=vQE%iJ~X zW>@cWN#)A(Y}Tz? zwD{%8ley?(MNZrPP)5E(7$kT_&7X{^-L z*eE|z8z{NJx$As}!0wipCnH%V$dV2cCR!f_rn=Ho7w;&0_V)H>Wo4b}OjH#zt)-K8 zo;%dNy7S<{?Q(K*1w}>ij*}g2?d=aDBA8WFRoh>k-V|^^X0L9BL4zoR8MkIEC+C?{ zr%1<)8-G0aLgXZ>Mju_4iX-t-PHH*S?_HIDOF3$%L`qzv$L6tt}81m5>0rw zao+y*vm+?H)JjKYAQrFemnyJ*Wl=@HmkGYbP zQdy&#kLs6unTttp-!?D{81b=2wz>0X4pPRNK3{m2vn(hi6yh}X>HOPwA8KpIO)swX z+MmpSwEdt?L0|Hv*I}~0w|n0WH+Q$SU1i}n7^-<55>mNu|9(l6ym#+<()Hg<3~R}} z-;h;buaxj!zrFR#m*=G(t1lvd_VnoK6xy35e8WCdJb%0KrQCaP!X;PTo$JVvH|iY= zsUZS`=RKL3nPHrn(k_<+o{fB*iG1@kOE?+RG14P;YMR`%bW zf9bU%pRn+=-D;mL`ztF?PyWzZ?ELiUy64LLqoFmu;ycGXySj9EVm55rniUX0H^vY| zGMmm_*Ovd|RTbYXEWEU=Z}cna-8+#RW@enU#u)~Z5ied`kzhMTwbU9XfD$q)cVEbt zQ#oQq%eeLUm3LHlFkK!iS68XpfIScA=H}co#kskuW$yh}u(!|YwIj>A?+pnF$yt|? zkufnd3!i=05G`o3@!Pj=*UPiH!83}oexZHB(BhjnM-rOukcMVww+jmke_8+X<;(Z(%!3*l8WO|f zAEO^VeoWQ!u6*qu?XRiNn(Uu_CZ;~8-t0`IKXKy3_0F7}oVEmI7HN+ay|)qv(8u~B62pKY|to(fbFsG39h2iRLNHOPZ z2XMB+e^%w%k2U^Xo-LQ`{5eqSKj^vsQuz4sNM6D!h>;{oV>V6K^Kx zDa(BP%7^D3^WV+}>=9O+%zXL&_E+B;f0opjLGdZAYiUJyXIq*CXqw90mnEhz7dhT4 z8~N1Ue!ARa_4<0b>#wOKR~-c^?l1o-(jM1{^c@_$J>Rcw|{2wHaW?UVD^XVw()QP-{TOOX@rglD#106yW{?;l?uKYm;>RQXVH zVUOtrrqHHv-ebIJa;}S1bwT^j_|Py(nC#lM%l6;mt-Q`M9L<(>Sy@?&x^OmAQ`4R= zUm83X)fJpXEj|r7dNS_ZS$Xf~G*zMmHx%EUReZ#-pT^1Edq=1-i-3`W zfx*5QS$98w7J+!D>8=t%Jf8+n2BK>2K*!h${a%sB?*sJs6kkd?2_osI8JEZWAzi+V zMJ?gKiWLn&HDNz-V%Mchm#FCIyklbz?iI75qNh*cPaUc9p)JW!+He)wO-WT%TPo&+ zS>x5!f4|7L=I7^G#jVLR>j;m5CVACpv5$GnpUH0~^5cyI_li;i4o^tBPj?;X)6b{# z_4OqS-u>A(ERteSS{SZn7pQlgEnqBt^Xg_1N3}2g8LXycXM})p5Q^8Ab)wQ*S@9if z&M|9JPq+{+J!11I`BKJG&0Ec+unS2dE6wAHTbu2WK zQVGtOVahIfN&;7rmW;=W#PUe69J0cXD7+uC>ow$06kC|lkRUa zMn|t74N;Q#YFO&(lsD4U(lW04Z1!97p<=D@i^~-cHn~i`)yw8Mrt&ra_JYc5^~lHf ze$NcQJnYYWT!Xj&VyvtgM~D)0@RKJ`s(xDe#Vc#aStsMz_aHZ1&b{$5DkNlQXM4Me zsAbo6oEqJpgq@~xMO|OAOg1^#+h0T3P1Aa-(cIARB>i%(ui1?ol#at%&FYEzd}c+- z-1U~X+vCp`##nAqUAq$6dX6?nS^Mln{l{qW%dxWVdS7zPo}+Z~Y9`Vx%+H%_KPcNQ zT82!UZ!;*{_4?fQ8|LP9fEmvfI81;lOi(vWnj!_vkpWvogIZ(datA$@1GUqygr7gM z>xPL*rIgduk?ByCaDHv$66eYrE#!t5BDboVKYk?s7NRo#n(j77u$$s{bp^OxnB@!QAi-E8VNMQ zA|ltXT)EQR(Ggl$SeWgyFrIyPc7QrV^NaPjupc>-$n-Ow>&pjIHIo=RKYunmdh}@X z=g;9kdh?V8511`vdiBP3tfZZPcC@*^{_*7GWHz!@HW8Z%NlCo>gG>4-0&&tB?HAWP zu0;LOcWDg!VN~w%k@m3T_Lf<9wF|O?jDF&olnG!o=O(VHpwg@lA3n@=*e+4bDznzt zmIVZ69&X!yeL>Fs-v=UxqL@d8hci7>I+8=NdbpLo=hf+lJ#yFgss zx6-jzP{%%h-hg=9Ts`hY=kD(AVglG%RpznUNJQ~->#J8QVx(O-STA}^ANR}G|4tq# zusbY34h{~sT0w<#BNi?)1WS^-}>G z70_>(*Z|B)o;UDeIjP4Elo|ULqkJ*3LPe#=N-Sdoi=$v`>ls2oO{g7 z+?-c#zRcI(zHE5uux!$eDSZ|{rA@89cx-hO%V?u}@1o0hz?lr*H*u5^9z z8@FyXU z`ELAovr71|sdwYA&#Ahtzho7t*r#5hQ&lQsJp0hJrEJV$`FV6S@4ho;$SR`_9@%zq`PE(9F`=}-}Z+v{g$I0pQJ_bOXH(Vz~=>xinH;Wz`=j`oI*aST4 z<_}m14K06XtuAm75HzS>DB4 z+ZvQO3v~k4_ZipL*3O{ddfB1x2>zU+t^Bs+FXF+=PTKun7?2FNv<{Y5yQtExA(XRZ z{_ZG&*R9Zf@XWsVvalrs98I^yJ#6fbc9Y(HOzpawnp$DTGU`yF(btYx)X1x{nosG2 zko<_OO4HevYkm<~_qF|+w)(O|V;FbZS$}4lky{;ao?NY#6qQ(u6}K6ft6KcoAI+kB zT2z8(w~f&ly2FcV&S>bqlz4k@{*C8#3%r}aXKrq;MgYCTtk!cwm1J9fe*PR60*Q__ zhHFk4?3MFyi?e#vCUo;_e)^0JYJf}SFw!+SOWus;Y~37^ZHZ+?UmrO#m~>S?-|AA- zc)XQuiNgeg0Qu3QM^wv)}wkVTkDBv`iJ~){)Vield(YoWuxhQ75 z$BrGlVPV0uT&^6+Po6_vjQj`yITHTR{%h-o4I4^+phaTK&CefK$<4~DcQUB*lE`lq zkvLJ?GurUXW%&2Zm&4I=*FI3SNTYi*+D9?@V^;iHvuOVK(sR}&aT8FZpeL%*Qt1kmNJ`3k#bkD`p zViF8&Vgb8vOnrK5F{>+26%sFdjf%ZI0K4G$qn8Q0>f*aP5$m~Q_wK_Ivv`B=pHo>} zf6u(l`kE9K#kw-_0M2Wj+uY2-Dp^7sMp`W&L3)O2+Tx)G1N(CQ|Zwe zuaq}zGOi*M>-$gMIdtd{$+*nz^tEdoqLw#}Ti(2R6$#0uHX2j_@ey@(BMr}v zNw`ps5QT{PFtdMq{%!rdj@8-nb^A4-myj#-j1k(YI>UfrG!xslZKJ(YdeXYDz@Yd! zJuR(f;n&ykubT68%01j~T3R-(d9JU?05|J1Ht^b`e$~CZ)yWvNAv`S18=dJ~VqIO` zYm}i2KO%|O|Jk>vva*sqIau!bC@CqaN+_PyG)mMmLc_4PB}Rr#Gf8bMp#1O=5FoD= z1FO)NQUY=X`HaiX2v982=XVS%+IC~l0`C5pFtqlA0l1JX* z2(aNSXi4C)R(r}l%NysJKdl)Q+H=aw%QLo=<4EHtlkNX`m|?fV!iCYKc(1K4FI3gv zHPsEhZf<^P6?k%S-4E=pRy}mxS$c0k7XaVEcqR;0pKNayEjbu_Me~q zCAM|J2PXndMVhjT~rINpb(9oSm;+QNz z$LE6cU;SvFGN#<}E|5(;BFc_>p-~u>q_SZXS_dJ|HMb*2jvNvdWjYsppeAVlNr|T5 zy`mS=-;E~80#nu)OnYOcKcG``?zHJE*l_pmU1`rX=~gF1;%flf0*8s9j>*s>+u@Ch zLHpij@h>hdOvv+Y`Z-vBaNG6+S`zEXTuXcJYVL1oX(6wV1gcgwPzTfOuL-BHR)n(EwJGH_x8hm1$LWo?uP=i-(PRo#q`_z z8gcm0O{s&j`jlyW5KR>MQR$5vH+u3cS)`?6?A^O}q!nkSfjmoH z*1dbH`ug^HczEPwp-rf&e7H^0be@|s^_f>IS$zW&kJc&V-}bkf$D~~-3S!25%%&(8 zS=rf_c>8rTjmB1tlhpR-edFc}5YkG~qDB)^je0%TZR|N;RWDahy!UL;31Q)CESBB+ znmf_6qx3cB&LC+7ahn4;y1Ww1(>}B;&fBje z^*%yJk^;Mtt)OIHy?x6a9f|zGe#UnQStaKd2+=Va*5u2%M}rQ^1%3a1 zb7i4bPEb&AC_wk{5h9$oU~y(W z@UFJMXD%z8J!^vVRkF>=v@WOy2a#g%K1D57HZa$`)jkomc(^{~u=Cu|28ILDRF;;O zc@%$u+&z2tB%}La*e~S`NNhUM@-pk)yHRgR`%$XQ2R!9RRif?}6iCeU7e|U;^w?L3 zHjMdu(Lz&A%~_omf@GX!-M&vk-Ik7rebY7uZ>1xq<*bc^1vQabF}z& zf^1nH0((HjaYF4uJ|)!Ab9mwF>+4u&rHdEOA(YK$Dsj}4rNv^RquJ1;rP-DtT>*wB z+rL0iy_$2I=}7Qzoaf}^90nB{1`Lnjz4WIxmxf=df<6q)(@6XS%(ZS{4Kbr0W^m?8*0md$=-L-Vj zIMb-~DpJx=owA7JpWK%(_Xi!2@dGdbMaRL%M>pM-er9%{l=k>`og?N(PhY+aL{O9d z&R=^fvi`ZF<6G(DlmI4Pombat{Aeb&?-fxRvy|rJQ!(*i-}OYx<1**TlWasyLh{+c z$oS!TQbIzFi{kLm5J*K5c;SY`k7UW+;N}pz_ah@C)9BLQzI`IM_V-Ns>ekUr*(TP9 zWo5Em8HWCN0~4fCF68cG&BWoIIATh7#+PpR^9CYVklaVWGnTM$XoA3^?`nJ@?Q&xF z-yT^v34Y_UfVnwaI#$6~>Is`5$gJ@;EAZ~KDz2+;D`sKG8 z3HCo+725S=!=S8o{cLk zlr&!>viYPmPRlVBH@A|jeBm>`#H6G;>gl1Ois>s_$@h}MK-G`4 zXho)gWd2s;fNZZ81*kY*#i+B!$ry}|8$`pNW6kKfU2cUQY|VQ`$?Ploq4}nuSnK1D z31GlI>r%(Gi`X}NIGQEWX%}mi+Ku>Ln0&g%%x{ofd!YF%(6Hn*T4@3EvdBml^R>C* zYRLtZ53_{};XE~Ep6eO5h7Ao3J~{M$K@V2+o`dbW?*4BvWHmE0vsKn(C0XZv0RYg1 zQGKYNC5$feP@7YGL=s*4dNs%Xa`+Hj|%$It6VW~Z$p;@ zXva!HzuJt}!1wRP&Dp@g7vKBCe=I<~OkdqFg=j-H#hd}k1ey^Olz?+_w0`|9==_Si)YHHpA?v#G7x>I{PSAJ;< zf{&z01mEQ<+GoOY1t{<~YdzVfF4KEYT&EH+P*Cvh_{+gbW|c2OU1O!r^Iq>gmQI{N_4g}-I2z{s{Q2{qT=VTf&QJD= zT2$26Q(eo}`+am6@K$noth{+=N1`fQRCILHTI9OxPEtu}siX~AJHy~P%Q&!O$t|SW zCAk}h(ZjC?MY9op(w z@~x=RceIEsWiF4lS}wgeNKZ?1w%Po(w>PDqKD~(IklxU=lrKmc=yv05Ky7j;ji7L> z`(~}oZ!4I%-==GHSTSHol}*zAh|P<4{;go8j98)p1;$6KNsyjuCS@-?cGHalLkqTJ z>jb?6_+fAh)QJgPZyidh?QCuD19%UW5CkgQFbK^Nx7F<{e|t}_HThKa9HT7xyncO{ zVE3MPy*a=8Q--1bSOKktXTD{3&>yB^d~e^reG;9c?-L|y>gpZ}9bIO+5m}Hzpf`(8 z$q#VQnC@yYYAoiLp(tH5H8nFn*)SnT%gk4S<)M+5p&0XR*}S>FXnuWVNjP(Q2#CUV z&8F?ugPBvA?}I9xxIe{k?h5SXtIybmdv!I=GFKkkR8vp#>hx^+Iwu_ zT+bs|3Q6M=pEHePa|Oh1_io<4eY;Am?0&Str=6V((t_Tmr_(=C(9nzK;F{>J03WD` znkQNPg3A4P8sFKaDqyGix#2T;Bgd_Pe#YnrL&L*;z?sCUy_703I&L4@6^CDlHqT_2 zr|8Fu3JO--->^xd(=R?Fj&8|~`ujcxR@^pi`sA{;SwhMvt4Dn}lKh^Qfq+pcDxi16g+dt4(jbxBUG}CNrS_hd^b6ayu zi=wTqAc4zqb}MRXwtx(?b#fXD+^^3_CX)%J@5{Pm@t{q|8)}<0O*6DH%30G+L+AS2 zySrJyJo40Ve^NDTG4}8kAP7IC<(>QX1sE3JnJ?-!$hX>qg29oR`&fZPC0dLg*tr@E zNw&wTt6KF?jLbbhGieL>DvVy|r8v=tC>@CsrwSj`;%=g$&pZ?+&ZJYOm3~EZK*TIY zh$%7lErSZxl8gGo!|ZPq*q|FV+AQ#B-VF8R;$P={tR7VYRedYiU#(kvUR|YiK`e0L zfv&u_4t#wsI36k+c?rjHB+x`_7l^*?XbZC;+-%@KZ{DT7#_w{KpSbKYKOWCGE(IP+yU^p9^U zp7YSX|<=&=!BzpTkCP{?{+25~HtqmdDuc>1c!Fn{a;_>ab6#mi6DC*Q8@Rc-3mhkRf$b3Zd&&hzo}6hO7#K+8?svDpGNTNF z65O_N(<{3#K*+` zV&LPU_SfflZku{c=a%>8={~*it!(GihH2e&q*y-&QHfph7t$6N9rww&h6#V{>LS2o zC3gF|&myP~W=)zxG~jhh#%IoKf?U5X!?--?*u|Jb{QMd9#ivqrGY?F7*0K`n#+zH} z=a8FiuC)BDeluGB&t>&rUa#TPdN~rLvRc4oEx$8o&OFpT)w(qPuQdgR4sajEDPbBn zJR5X>`|N#SKYf&ni)^X7CK1l9shBY4X@-CuQ{64q1%k0e(RK2S2SQ;Wx);!~L;^a| z!VI&HtPG%5VcBi|)t9KH-OkE6eFqe|?zur(b6yrRKOC@WuYFDNNx$?~+u zyPpVq`?biy2(S8b4^Dadb1AYtd^`}2N-}Tz><#eQJ5H0T+(Aenlp6Hh96UUJoq1#h zjt3B5hA2RpwY_dFK({DDxZ1#a|}xIF7VYiR-@swT^%+PU-dZE0xO zL#0}v=o>A`?uc;3-~;JhQk28yhl>zcL!0ioP|W(FrYw|w5ET&_efb{pmbRoj+K36q z2J!+p{G+sy3(DiWiQ3YN=J%NGa|?#?QEkdJgv3~)J~27jFr8tP0m!Xr`bOOQgCGD(Ex~=h&Xog0Kq{Fpd%aSX;Z*$jHc|M%?|M4)X)#WbMr=w$7bNWyy>8J>M!B=pKP=F+}jmzQ$*OjQJJcDM$`H7dJMmUbN? zi?2pfIG0R$9apH?6}10Tx3eM6I~zKvLTg8TQB1#*ch%f{8+I_SPMG^9rQd}J(uMJl z4*|h=bkeAh0|!`dP#})IntMPp4G=yc564$Ve}&ufZ+dX7C6g2zquRPUhg#X)PZWH* z1Ydm3_wn^@tZEQfA;}fhgdB66`RuMd2g+y5F5MO(P0}GLsVnJ=VMi!m7ytdTL~8UE z{((ij;Qoa}&RSDNg>GUU6fQUs@oov4Z&#WKF|kC^i)?r@Tm)QEp^@^I;}g|N+;Kns zfg@U@v57Q8b)v`DcZ{NsDaov4I4Db=P~JrB)ylzLyEa0U^0A2q05||5j&Y^xhyYX) z^1Aj9*HkVEDJcr&7x+fX0cINaVSX9e`*XUhlx}l%b@dp5^!f=Ff)$flBs7|&rg!mo z;AOzzYJan9V@5a&DWyfzNFa zd}N9N;NbAjpBeVjP%C^gGX;`_3IJ75;K(s{KXXaHtL7dpp`x6ns!_+{rUK zG_6+Bn}g)fX{0?G4e zCEtf6XA19HPl=20F*h*7A;H5(nf+Z%1A<{gNncn+ES3m7Aj}nDDE{*qwT!n`i{1oJ zT>j)V5}(n^tYCO%59R9?62af8etmcAS%+N&<5bt>`1p8<_3l*N?NU-wTf+I|>pdCQ zwQtB%H6&qs#-#>7ebO4%lad6mJH^NBa6r(i_Z>B8qWgh?^arG!?}H&?Ze-P(f+cNe z<>2n!4xNoH+-$Y^vLkgDs?6ae)Ph(<&z~%g1 zUUrB`y^?1UGof?iMj~HefKZ9Y>fW`rHS*SrB*1uGOXIwo?HeHw<*fT`O=?vIOPpcs z42F(rN&gWPH$9**^DE$*vK=QiBu<_jfuASpgc*%ZfGj%$0|WWdCqSkLu&KB%et&@Y zHid)+bIM22Ag{765H~e0g_Rv0yWr%~aCOlp`oqq^7b`uXlC77xK|59owE4*|QN{>{ z7e**~u^&*QfZ=M|;?BfnW(FhSQ)=b&va?Tval5|$XSzGsLB?(M9}`YX9XL7@bMr^& zMokf|E{W!_SrEbypC^=^O)o?(eYrIgUxBoT#qG3)#xC?-I}aQP0stOC{5|&f-|c1_ z%eLh3V|DpY4;&I#-rZzR?U@iAzitBQomJXdaANq1qoWAQZ7@EmqN-{m4)ot2?*}EC z_OPGw^&fQ2(3lZm|eKG4N4N>HHFRYG!h0> z^r%?XmiK5<-(m4#6;j0d?AX8mc)@K7{hlo$o*PycX#Td}{tKe{Rc7Y)p*sD)oRbeO zw&e6ft^h1-E*2ma28+Kqc>er(@>Ykq49l^mSi)_ZsR*&)2>^xzWe@_tDX=(#M3G?g zPZj}=v0M?)tO%}sKrVHu7pG&`b3(yZ5_%{yAJ^@+R~u$#W=TEkQ7_LTJgqVf#v+RvVW%mr0p-)6%)^{Gi!DC$prf{THD(pMw)utFZJxhdoX-z~ zG?4s_hx^_Zw`uL~$o8JzUND|4gh_1n>`te;TwV(Vm)9RJ6z4eL#sc@{|2Sd;xV3`8 zssl1g2yBHp>Hxm)15XRdV?a|Kl2d!3gZVxwN2-e8f&2CYxwm0`VEcvCL<6`&Kr7+^ zA^i}_)Z?0KC2mW*u)5>N50|nqB_I&&_DaF6m14gSzfwi;RZdE@w6vEug(f6$qH#Ua zbmqZ9ffGT{Y2LE<`YD44RL`@}mUjEAF7|P7@Su*4&K!l}1y68>L17i|^b-Y+N&)tT z(7}IcDG0Z%5Czq$$ML5_P!GXrljLgOgxTyP#KWt)-~F^wb?5=ZP|KgHp0phz5gt@z zLx^jpa2-fsS!LfRe{YC9{yIM1_V1jDgoFgr3a=agVL7$Dm8a4$t22Z4Tx!EFZ9SDh zI5^FkiV1r87x=2(o=`s)0y0&epCGTJhPX_I8P`G+5SHWT4`Qhu#7`VOa>SeED1578 z>o{LLNlYYKYU)dtgGj+fWA5chLPpY!R>kqFQQa56>v5$bneb%VAnBG6epT(DZKPAFl=h);VxBftDt3d1O@$yy5M47s@e9c2F``B|a zL&JXQ5lM(*G4rE+p7QgIi+TCB!}4kc4>oP1@#W>BEXABCC%Xjlgh`$YYuE%6WmG~! zLdPVNf6WX`iDU`T{UK}NUq;?a2QEg`ZMS$5bEv>kwSXpWH>Ru-$MVkLBa-6L4PdY>qsAYw@_!{`` z!rH^J#tk+P-CW)tdaIegsRB;d%iw*F@y9Kb<4-7Bdj~lf$;JIjGbq%U9HXbF|AHp$ zbDw`BWSQ#-RVD*Z@B(^w?&J^I!|JE*dh)|HAy_vqJxDVA-jfpA%9HB`8c#q4b=1Yq z$HzwqpP^QO(3|@m+g6x9RT}rIWfHs^IvU~dmRtD$-)pP)!n?zx$!5Epn)b>Aus$y? zz8CZLNbJE>2Z8$3phT_%OKw_UoGC?XpCk{4IiceFOlxcDs3-D0Tj&O8lRpnMpqsq% zba3q6uc?}^F(?d11?W=es(uWVMzg5N2Bh3#z99Z;z?3lbra11~AZToCoY$3b!Cgnz zn5p-ie0PMK!#$&^D*zjsj03yL0)^Vtr3GesBQ& zY4HAhyThKPe`oacqC)S5?vQ7E&94n3jBsM)1h zf}8tE_v-i7n`n;+dq&j~4EHs_85Qij#1;87XJBRjC|&*H`SW4?Udj6zNC;?LBBs0LAqkyD@GKbA!8nwCvz?=oKhYcX zY3k<$3&7DfWTfy&UJD?93Fx;51_l(mRFDl=$!kLu4@z8sTUzs$>rwBkzJ5K~>NGP@ zda_%fe=Gf7@BauJ=Z}O+j~wSh;7+)?^AZN?KU-wf1(#UnUY|U4k(;DbYBbm z-e(j5`Z2)j_I_xU`xxry<}F*|fFgWe98gtN_30mBfWK6j`E}8JChUFU3`4C$1l?o* z^1l?YdfI7p5XTe625 z`$U+)gD$^6X_4VAc#Q4i=^uFlX#jlAtF9uy*Ro%;OwN#H?&-XhZilB72U#<<$pHd; zn*Mt=6PKFiW;!@ZG$aSUb)kTU-+TTzAYgRykos#y`tYJ;PVXIOFWd=yx`Ps{8z;X- zK~eEqbCeK45JAg8sQ`+703i{tnpE(N>u|arRQ5|^i4jn)EIz0B;Bmgel~6JtOCiw1 z+@C*nXuk+=aZgXrsf!o4V@g3$Q8DZN`*Wb+vh79`A=}Lt`SMP}7+{C&Y*1jk0i^6H z7;w${XDlo&tATS+d#Yg=2_kecaQ+tQ=aiLcVFbGgXCmPuC$I|M(?vF*IdFyb2qZ(D zlNq7%=gw_|lS8lELk0;4xMw?hTnR}@6MQyWvC%-!_~hws6yISORN;K0zutWw{1Fk1 zcmS zuJ;v7tElrD+-!5G9rA20n4&-ax^T9p#$G1Z0fKnjR%J7g@TV{FflQ~+jyg8;L**X z0~v@Ct*56)kY%uioi#8ZUo|~E_nI~jHTxPkIATKNU0$BawQF8zBCkOlC3rWDemCNn zqWBTm)S^3s9wRldND(#^I9udk#q~n>YZt-Ym=O==#BQcX2p%&gEFXa+h;g9z)`onB zMT~eJFVia+xJ5XfMc`Ge3T`CtW?HXh9{bzS6lW0TJ44R z!&vWp^Ly~4-_Xhs4G(IV7v#0m@ZqI37OR1hrKHB?Hzg${wjiaj(-qClX-dZL)WgKzS@&NVrZq&mtOH`3FJE8X9upIj zY*;LomX-!!6ZiV{LoEJlr$uk)wkni?`L0MUX`()kp(j$Mo zAu3n?Vh{a>s*GBd zidzCMRz*BsNJPX@P|Lj{gmaBI>=#zXuJ-Q2!a_B&D>34>Lp}_%bR?_6GQ;2H>FKvK zOrL{HtM$y5iMp$&~x&U(4~vM*uGF!iGdfYHXFK@~6-Vr$!%MZA+i@f~_>-I)86r~2IwIV+KZ zB>EN61bXddqvj@o!8{9aY)ctE@3yvbSxn~e;Y~Cb`|E)RDB%|JkN;JxKf-oJn8SiS8IsLMZX^4od1J>vtx^ zPbJS_J!)Qh%P%Y~n72eKYaj+S6ICjKmi*YYQ(%7;0+Vmh*KlYw@|av%YFJFnf$`SZ zVCvmR8E<LZ#Txn&cS>r zA}UJcGD4L7{Fw>4{UJ8p?pWjwjQ|+mz3`-m;O&4QSPQY~X;c&y@-Pq}q5c7A5L64_ z@#Fp{3-P!A=Mkv5NgTDTygwEGum2{V>p$P|P!sVbw&8#O+<*N^;Wojj{U2Yw#71zp z|HoH9DPpYhfBv%fj=W3%+1dZ`v;QCZNDoP1l^stMYF_w41U`j)i4EuOa@?W8h#-v4VSl)EcEyZ54XtTu`I*FH`gEN^yXjen z)F_cdL=C;e7_q5Dvlswin30izW3|S8a5Pc`HWDACI-By<+UP$x`~-MQ#&`o-HWScG zs9G40StCqmG{;ptE(EPyL{f!k=mXl~`Ju{91NLBRT!Tr5cs2N5vW<8IRk*% z%LFY@Hd@tpR$q4Efrxod&~4H(GIo%&!~Z;RlvC<1wjaBY5K#zmgg~C?IjZ1KN2%%R z>Y9W54tph0_ZsFXg5!zcH}pf+dmFB(al#$U6456p9r4P}wzE6-Xr&G|F%x8*x%v6i z5OU%Y60~OmJx3pX{~uB|MQU42gw1r{(3(#RH~VuUYiyR;93 z@B%z(&Wqsymv^u2^JgEWBl0kRYF%)+RH6{V#=^ov@F(T#YbP;l=*LO#O-hN3`&2q0 z03$HWz!m?6__&PAq+;nP6Zods#Bo}$GKh#3cMmcWtV?P)9)FqgmL zg{*iLS~kX8MX`ZfEG?(MskL03MUWGCjnL1LD zEwEP*gB$XyvA}TfUS2QXd03HnD71=aVAt5WfB#47Sh%Y_*Os=C+Td7KAxrFpWZ;Xv z1U*SWe!yKy*KStU#}I-Ef#Q2l_F=-2{6xVEvr};Zh_Da}JT+s)Y$7+vKw^Ukc>$_t zEqph_Gcy4=F5ie_16j;$I2A?;H2uTq@Q4{r#HHjs=h36_piOr8*YuR$6UXMbOf%xo zpp>gbZuewufE8zW1ppx)V)pmJayg?iwy;Lr?Z45H}JN zu|`nq7(fUtLP(SlOZb2^>A3^=1b)DW!mDryc%-4)~C|}sLus(M9idqi@#;fQwVGDiu@ZqM~p#8My@nI6W zh9|PMw{QHFpymRixzgBk)d$%76fz1dFg&x*0BtyqA6K2617!`<<;4uA=uoxZSGjIW z6y(&UV2vcYgqB}rPSZM=|K<=B+(8h>=xT@wc&jftTi{GO14{PT3b5{fa6aZp32_eY zSRw)iYqOl^uYrgg!h3Eh5};ZNcfzbfc6K(a@C_BaV?IygE9*T6aBP}D_*W2YU_Q~7b0)fJ1{6-M-oJn zR=L#ij$reIWiaUs(*iBOFb3JT-UkN|EG#1_`J5PA!WizIu|_5}0|0Y6l75M^U0LL% zQJ~uY%vJ5(o0FW(RVfG_+8j4746Dv!4>wi(XE=8No%vm2bQ~!}MPEOf7~pf79eBCnD$zfg zUArvo_qdF6`#Vm^dn}0D8qkaS%5|JK#oqH{y47WMniB$ix4JJ{ zf8u^<(_a*BSnlzXWeqzG{?49`bkcrQDelc5^ANluG11u=al8pW9?1n9$_Cgv8VH9_ z)A`nFm>TBBU&?vLfEm%rG?J$gX>VZO-xNU0VWY0duQivRBUECB;yAdxr#$}CAa_- z2jtUO^SfNQD&Q12PK?EG!LJa_KTbGiqYcX5@BiEWs~v2vI=XPJOrw|L#{Zc02u@_- zO@CX7=ep5Kx6gxAemx2X9$Q*{MM z70A&NotQ`s0D1ZvmBbb{iVu0vtVk7jBT8ui2O7h(pmOof8^F_73W6vhY%MnbP>1lG zKG>dN{E>7Ym6ih&LrAr#V?|ZJ;E-np8Q^8S(;5k!BxuzmSZVh9_3Lc+DGXpikgi_I`_M?Z|9dy&OwAEO1)iY>0z|*2 z1TMfr@g(IQY5XEU4DSs8_(6}G5Lj7Ri7o+r!-lz$rccPdukk5Z1(^LUVvIv&-Tv-? z^e!kY{1D9vdWsMCmiLobO^JC6LW|(#<%R#ASb1w}Ym{E6I4kkF!v_h(13fEqj=ajW zg5iE!2Zx84CjcrW_%b8_?wPWSQu!*l^?-nR*x1t?A@#M^e<@?Mpo|zm{#{6K-dM%5m!-l{Ektoh z&Ol_GGw>-JjZ1L#5iA+usH*mymX4xS+rC(gA7A;xsS=+PHG2b z-73(7sN*IPL~~SBR7HQk;h3E`ZkE`J#DbP%VQLk)q6#v>PFB`r%AGI;I4X(DK)|G$ z+_^f=pULgPTEs2wvh%C{YK z&MUi-TH=Ns1Pj`WEssLD`KQ9-v`PmqzLfI8JtEGouJQO0vJ0flWX$|n50o^x{MJEU zTbvy{2+II@4vR;;8|gD8Ig}shANZ7#=2W6$0^rTZQJUVAVGTZ3mze>Ld!w z{(a)Qh^g}$paB%&Ona4@N(+SSINe2$fz(|ySEMrMZ#ql|N#O(lh(^jp68#1&c{`bz zYiXl%`;b@+2|EF3H6Z=7_z2>D3!i0Pq%|}#4Ro zPk{laoipDI`$e9uN5sBgr`3r67$vj`2T&0hp$#Jtn0UGhJ{?TH0#Ga1T{9D?oozX0 z%5MJ_sj)oYali;rjB#)RuHmmhe=+u95Ho$?1zE+!#8_eShCpl!UKhqGZIk!WXtk;V zYl#b)0661->zx12u@jIK;tjiq2qTV9L-!rrGO!aQ@)<}yn@J0-%TW=*8_lF z8&=G z!-o&g^c4zs`?^5Zd1u+p0GDzFNZU(jtDqx4`2zb4>gQ&*_zm?3IR3y;k1~&^{{Z%p zcya6wx0S?4dQ&=ZrMQ>nC>)C0Q3Uk#^+!N<>B0Ct8&~>nT6@wdlnvv{(MM%wTgj3S z@!qLqWK&t$0~oL*5TY-9=#>~OL_9kG`MKYEuq=>}$FROW zIUljnii9_%b%_a?nF`nIc;kK)#E6Ndr9WoeAUDJTNEY&c#m6crD7-@b%*j^2zdlU? z5MmH!;(Mn8v(fL$}z9u8a$1GttN6ToVxw8DJP}14?q% zko$=wkS~Y0xTruZ6VnrfMH21nH#EHkFms_^YQ4IpooBTmTnuYzUw{8Rnmpo= zmA^t#us30DFPX6$7PX;?iERid$Jv4X0CL2L$=ceo4Y8xB^3X8$*+^;!+2A$tN?{dDxGQ z3y}D3;wFQ%pIQ(xk6^41F|GIBdjE~qnB)G!yu7qBu8S2ACoVw^fFikx7&IkrdP886 zap4V?1X$x~P3F|<(60`v^iUQ;<5AqX@VI@K_ z3**}074(_K;58#oJIssi2u)n9LAaQSad>D%xGu;5_U&i8%cWXPD9Zg45;!~Ou%$%M zA>xMtad3r%1T=16sE`A5;x%zxjq1R*TSz{w>VOxI@nvE*!dG>tMwpA%B|^Hwh%yZl zl;cd#PUrzjDk_Ac24g_=o)hEHNdXWluuH`FB`DWZ&{1KnCPpP%V&#G%?6$~UL-USn zkM%3MJ{m&mo_Ni4H|#(}+dGBW0O2F+-0y&wK~f zgkXM9joM&fBPN!yS461ULFj;=Z6lUc$CZsXw}d)bH>R0eSa>&4umjsm2(QiSQjQeq z|CG|XkuSmJ<#G_%hNh=!F^B?aUD?{Y9YF{O1#8kuqS}iFPxWjM&T*1}){X zt^oEnXA%d>48!Aeq)gr*NUkSvj|eV_;GOnUodP?i0-Lp{RgZf5?UwVsrML{?>*fg< z>mGbgzVvXzRa}o@1$xM;%A1+dll3NOOTwKDo}naSxv2ge9_TUC>m0? zJFE=?0s^99V`FE;O{z8%gH%U%Jt2BO$z`n=$LBb9R%wrJ8ooIHNK zBY>ETM#~P9R9VEOQN#x(E`l~(1$|;d@CcB)X8`S7_<2QdG_!8Kk9XdBA3oi4CVG=D z2lwo00Kjh{Wucfgjw3F?Uce6;wtw4FhM$3p0%ms)DIe$b3+RpiMc$cs_1v%T{%Z>1F=Z&3=RzVBnKDE&G>Ayj zz*DA@p%fAsQpQ3mB&8B5Lxx18s1%vXkf|e{jxP`~3FyJnH-X zeBSSSxUTEIZW9$epB|yxUubK4gJU-vL$^f*bh!5Elaqgra-20?2=@;g9FqWMi>Asj zgxEC=*4AFY+YX%ihhIP`X@|YYlSh~rWVVSMXUW7xKO{86Mxs$Kd;k6{vIy76-o1Ja z!pDFYl()TgZ?{Zw^?OiqQ0Ju((Vq@YpE|Wb{ZnL(XYRii6iwxi@o4V#u9d9=-h12N>4Sh^E$ z2XDHkJEv#i)|jmyHejaMrNFUZSW;NB3c6;2Vm4)@?t@2q$v~)YiuFvv1Q)=2Csn*# z9j)x8H)4eAM<331R!YjoIyN$-(yyxevz@g*mNuvf*L%OJdT06-VA5VOTbx3^XA-2R zjqFx>)eROnh-38@PPByCcEDc>+vg5DsGXu5*&Z9}c4DPN+PUgk)~T z}(cQ`JG6>)2fT^)1+_EOT=HfSHRJ6NWl;c+pQI z=0x#B3gLzS@pZb<{^zyDF-LI@d;F{T$|1iUJM75z#ywiOR2tOOzUMk*p`5N`Lziqa zyLs)GvV9p(t`3(2gcngH!^>(q|8~inup~QSG3WXu9cydUcEG$L&QQR%$b0!&@swcs zD_0%98D$VPW#+2GMU!&+Q!1^d|I1E-N*Fw*CTvdopL=>o?@hyfRr@?7doz&IzHk?siDcr=sC2hkpQG_qXS2m9vpomKB^;bl5EKxQN@p*F+W+yOblNeb!-@mumhtE`SA7I zH$no!v8up@l`qFbo`z$U#B^I~aPO;nC9zA4I%m;u%ssQw;OX15$R4@Mz1G)OyI5C; zG;g_Luf;*;{WT6QE6d}Fy)aul*N2zTJIA+{84OGQkz3$>=FHEMLsJ|mbW`~1LCBa0 zRUO3ulfqHOb9d~vbA`v=;=I-W0&)kdqTTZ&sg%7N3@&Rn!{lb*%F!iBV`I)5FQeq` zuIV{t1H?+~f}2~G@W0)64zA|wN|NwOO7DF;zfep7ur{znW#mvrL7jnC73BHQ4PQ@Q zj;bOOFHz!VKKU&VBruOM))E%xQOr@J^c`7x7CG{WC-tzUYJFwpN-Wy}uQNLS?Aap!&sCH*`M>v$a=sG{$YqR{e*eC-MfUcoqc;4`nmh5S z+g{g@8Hi=XVU7Z3|Dv@K&A*t}+!BxMl;+)=>#1VYnncydKx^Yx?~SvjYYI+&yB&>1j5miL_JXdDW>` zY4w|qx2Xc%e2i&P=9{F^nm;p79{X}~s4iUqxHaZ>Et9@QJlS?k$uO+wJ9W4Pnhc4Q z@GSNZTt_XL@vZl{5AOn&Dw~?<=x*7%b+AHt{*1*u%+t+Vt$5=I_v4`WQ_?wYz0Tvg z;Q1t8{NnltpU(b#cPb4H4tlSwuGKg?gCLdM9pUo2e!c5>s08>y_ zd;v)UskZT^TE>&BxNtpeqDgv#?dqUCOqE9)aI(Xsj$dX5eefp1v-Lf8fYM`2W#P_lM@|1Q1NRSzdGt^#HxgW6{!&)&_ zG$e{MIFF9ZY1b7+CNFwc-(N8rzDnuTX@d(5!`-Z(^;nx1XsK?l>w9>rLZ3f~g;gtx9IR(Ph$x@_NGYLqe^}+^f+Hh6(rDEV&hGc-gQ25! zZ28sU`uZW%4huPZ$O2*>ek(CF8KR2;gNxeS&Gf)~w%@W{xO};m+d87^cKCod#l<~> zd##&n`ES1}!m9km%!ZE~odb48ToJLTgGKD@6vY|Dly@R0zgK`mmG}ULEg~=S;9Vjm zE-VO|e)!Ab{x73?51}+r>hpbU9Xp@-1?GwR`EbDBxEvPHRdVO@)Ae|MQRRf8_uYLb zRehy=c!v!i%?YC-Mbiltl)?a6=F_Gu zfU;Ba?U2DuI06lGfra^Xw$!N1E+pHu zTU^Oa_2*v@?itIGfT?k4eF%kMZ(GEJGUs&nnGWFNh3V%jPWqh{G|)ORiI7nTJ@2rv zkB)qd9hm4O5ysXeQEK%bk*qgld4W3dXP_CwOKSn-L2MnE_Z}yocM`dvD z;jC*(+PR@y&vsrPVZ2RVX~51IED72&443wpFH3?-2gsvfhUAiIsD}6 zLkUhBN>`62WsIktw70twr3u1Df0RxqxO?~R>m6RcCXgpAn;ua#T;oBT9XhdBu3RAm zGU27}r#*AupCdg+jt0SQa3v0`^U{)X&8`M^$=IER1B2>%RuAmfy}QHrvvukz-A?N= z^T1_?;3HF))hJbjH(iUgBjWX5>z@WMCc)H(A;%Xafrk0kl}h*W&%BLX zea-H@d-S;d(eYiY9k;ya8Bcz2s47`WLAISu9ZuW(m&{MwdM9Z5n-XM>)nCe5Gaay< zZ>)JWPpj$+xjhN6b z{P4jaC`3SCF@*0ew&$M_&M>-Si|3oS3pY%Lr}zwzE@tn4KjFYKYoI*u^2&`MM9?xm zbT22Z)=Zf+$qQySciuX{tftU{DFC2n+^ZHA{NBI+6;`wie$C0RJ_H;#8>QJ{eRpp_ ze6jVF*&>zG2tEe`dWu>VHy8VPO6|6A&#Hg?I7VNj{C-`fSi93?hRd}Y<>?j{(X9HK zzPtnPqKB8RT&brhRo(cIV1HlKlMJt%(e&sFdEy3lNcN0e`F%}K6G|u0nC+xR5s=GR z3ovL%XsG*52W){GD=Kg^NYOWbdi5hHl>LK;?*HtYU*o(*&w;KZ<2 zJ+LkasBzf3yXUMI^nk<7anj5hhHp$@@S5ZE262nZ+;a3D^y;}X z6!tGABg1!a+xhsyP{kl9u9xz`Bp?!$w$aWG15Pcb*wo5-)$O?Q*`(-Y?v zlMC4K`O7*D0&?P%6Gqi7w(i{*P)y(v5JJ(LGQw)EdAIo5b~W|MVJ%|1O#bPXW7+5z z<{XNn2y47QPg-6nvsVtNcItP|+i}ZQG`7?hA!VEo8@~m7o=cHGX;Uv{=j=FrdVr#j zx|bWg5RYW9e^u8<;1wd40RsclI8>;(M`D;j&i#z@nK;$cQ`Z=s3Y5(@tJz9nFN#5e>h(ok!oQ@IVK?37r@6UeP;!xGak{+RltMpA4sX9D7 zHTS^TIB}&z(g5|{fUMCrHVOrWEdydv1Tc9gXoe^*S^$cBH!G?Xl@E&Q1-zRCS6jy8 zE)D4Z;xu^S=j;#oPT$r57I{7+CdCEr>ULIEfWPMM2Zbq{Bn%>Cysf5wH!JJtsZ;(8 zg@c@%Mi^L#y{h8qmoDE%*9+US&T)NY)cyIx=l@&iHlZeT<=X7UHhg`Nd?S*UVM7_1 zqqXCF-2C(V`)G91G3WYUvP2yu!Cv?bt*Nyd$G>+1LrSUfjhxbV_!iLF2~^Ns+wNch zRRZ@7mQ(OUFwmX=jv^jJD~GF`GPLA~q;789T0FKOG6lGy1pr-4h~Kv&(r}LUj9$A; zdiFD!x<04L%*un_yN|#f_PAEFWrkHuz=jBOy<`x(8@_EcZD5$lq%cgkGkOxfZZhhv z>+W@#3cW#xPXJTsTXQC3e$@&(s%5)&4f%hdSr~%n$*H?uC#&R< z+t-rJc9a#-fjJ;IEhGhaH)tY62}n`H>$|p=@v%+U);h!1P!u%qi`UQmN+&xJ)jAf9 znusf!?A*|EN#)}!uDPfV#ZqM&jLlEiFyZ23m-Wcl7a9YO^xcBg3ZO^n|Na)TYz2nYg9$v>? z7BdbB?lOQ-_~-O9BGa9;&29Ep)oc^gQFC;TOTfBbGedl1nhrr!Eowvb&W)6nZxC2u zh@0_X>QOX|GdDLsU2>3{6^70kHR1@8`0q@=yUl9bYd$4wV|8#7#N^D{uTBat#>T-c zb|BMMR9{=%3y(fY%P>-cnqN7t4!eh;s+-bypKos-v}%;kaZjYls!gg7L6 z#UYixS(TBOOYQ*Yzoy@T0$iAE-E_uR)wN?E8Xo^lzVIm>%b{+ms@m7>+Ar10*`}Uo z;Z!soBW4ZXeLe|+j_Zxde|}eg?*7+{<^|C&f7TW~oRPch>8S!Pn#H6{nwOh3Z>{Uy zz98!WpfG)HYGY2rU??`fd{}>d)9hMaoCF^rJ05%!rRJTiY#fI%@d^s)jORN7WU~dE z?516~Z~;eQB}SSnV0U0->L|D%TQEG56j5=iUD-2ZI@*0x%@(PylBY|}BiE5Q?!y_$ zNY1FYUd0$RO1MBN>UiepwQf?E#4o(7=GSq=(418grDXVuG$S*RWb7TO@`Lj2Go^WxElnCV5-q9?f)UD49mm}IjM_H(=O2qJDJE+_ z(j>*4uidcmSnagoBTMx#UfhnypAGXqlZ=g1F$k6jiH^P_j?|Qsd)kOe#LUxvxId3s z{lciK(}d;i*CS`&zab6Zn*BdjCx7be+sU{1>zVQQ*V%(PQ~9L)YPWu)DI0eskp4W;u^DnLNN%Ye>V)EvFx0yHUjxPA;<%OlerVAqIRt8iL zc@j6^TB(nQyRo^s3fbG-Xg=4_zT{x-{-x`C)!ZGn`SseL-%b_Y>=jlS7m$!n_xK%! zYa6too2X|C*12%IlH!e*oFxO|Lhy6H^yijdQQvD}OQOnSO710A+O1=%F zH}B$=E8X!u!AD|aVq&7rMLs}11*y@#*$=&vnOk0R*U2#G#Wofe+7z^_?<2b`5L<3V z?N?X}!enV__;f@H+Pu%;5I)U=cXb@0L|6QNs$s`>`rUW0aCW{9iFrS6{%H6+(QaOP zd=s%<1^+g%|C<`0CpjXM7`!MumEuX9n3yQA`_IHsBaS?t!u9Y+&%bmH>0^FoJ$I8* zJ>&Il9=V1*$Kc5z;=p+nF!U~)^>4OWr5Z#1^7dhL>*Z_NO`+)XF{P)Dj`@Y>XmA35 zT>h{Kkxpj-nuhc)7uSB(n08oyJ1I|$JXw#UCMErqh2pPM$~;~-L-*b=RIlP zX}4jw8t>n};>ycbdpoeV>Xy-uK68N}3q^~S#X`c@TFeW^%` zAo=BkQAjjdJJ|ks-~~#!BTLk$1oxOv>G@v4jg+y#Cgwi2U><)`OuZyb7)*W&i|T^mBq zZ?iRq=6N(F4K&~7Hrx*{*Ukzr3qLcHFO2&^+B*cMqCknba%(TktuB>!jiBGTvCKJf zNxrynP;u8)#LbO7;{PRcqs*kyUWT{>Rzt`p%=rv2*`BI83DfiPC6Gz2<_1iDpHXY6c1b$w;52}SXfylV0W zzU<~@nEn&yXp)n*?l$;-o_AcY`o44j>GzWZN6l`fQ3RyeOqhF+HfdD|8qsGPJ7s>l zI%^&z?Dbv0bf_vN_##EV0f$rdqG~e%6R|?!>(}1k`e7UX{FWvQ;rljjhiI&*9mfDC z+Jpzwn*9NSDP*S?6tYZ(-M!n4jA{{+a-S{@$}W`<+2(N*5dWQYDqTr{{Nw$ry7x)V z&sSxmg$SAo_4@RG)+F#1KX)4?%MEljeeMZl8cXw}`Tv$V8RE~8rFHHKBMRf7M zD!#s-RWL|MJs@7J*LENP(;Ru+VA+#Z**}lo)9d-gugeJ_Zp2H?MV5e2NH~Bqo_JfY z8tych88~wYrwqkBh7P_Zq&J$$j4l5vyr|auyeN#yc0E|D5mTyr$r6h!VDQM5 zZa^OqYU)!ARNoOwgg8YHL}X0~-XQZ){4}&f6gCM1xNsp?=q7Z` z3`D4?2gcI$qr^9V*Tfx18CR6lI8uK8*j*gdjhv85&W(2SrOulEG zJplPn#fF--&{Ut2S?b*~?jzv7yuUu8Mf}13SIXkF`4XdH(EK@L(;WvRI@= zr*Z}~Tsy7(2-i1F@hphNBd?sc%qU{&@al8N#wIrJq<>p6klW5CfVuf!cpPUrX_Yq? zEnJuYUO7;Kl4Wg(JI1B5J%su-p5DQE|K~4X>=k)5Cf7W*!fyj9BlGDn zcwy6e*8wE%1SO&pRr*kd?eocUaO4wa$v6 zAa~QM|2He?@becBqg;%NFCy}GNUvrx$Z2VCi~{I5yqo1LbV6+k43&-8IU&nHXtVkO zo(@K6y(*!om{mSOi!j2!a;;db=L8*x1=M{2xY50RDN%_&MU@xA+4DP_^&ciu6`Ia? z`t)6@NZ2WkGb5n@&S3kO@%o>}jrZK0Vr*>B44{xw0B{wrJ;5`2FJI~6g6tZ5eQXzo z_Ack|S=!nj%P6zBg@W@@!RU9D$jSGvy5RU=OC1{oDJ8~SePO!1-S>TYii-}LdAtWN zC(|OOriI&hVrFQ<)UnKI7KP?C=nq;h550~J(YhoH{yCuhP$mB?tM%p1$j<`n*f2#= zPA&zo@+&6nRi%eGhJ-zd)pq-a2W83*Qwphrseb zNgqF4dgS1H2vsQZtF!#LJK13{+|gHmliIA^4{E1zNdJ7!@YwL{E6gL-aI%Tb&`-KBYcUhdCY-bsVfty{JXA_9q-Fq;TYnuqY@Ax{&R zz0mzxPEOAPfvNL%GhDy3G^d-MHb+Jzq_~Y5mmoO=qsb-_xL~8c%|6d|Gw^{T+9R zH!7PIsI$(u^00RW(`Id%{Rt)vWHThxUXed)3h+9LF`+D$a7evBk94c>GV5$!s^_P% zIV1m_e)^sQEuW0~@1B<`n~ec3c7EEh20UFVROd*j&AHA&G*yFy)3Uceex@I1dFuY%ySx0HKj*|Atwf{7^LAmM zx@V-Jh^dzc5#?|HUQY(bm?)Ok-Mhw=6L5u4UASJwRHfu@?f%{jj%LFX2jTvIWy7CQ zi`*ytocn0e+jox*Rrc%;1KdgxVAXOuVU|yg`Lt(5Q0=h$?RM4|`=^iRjvpz-ez|i! zrKY=mOwM=Eao}AbNo*}?9UnSpS}cI1cCxhqBn`RtM%;Fsf877bnhwb;;HBY|u99B= zE6=1Y4Hr9_6#tE#H|hKlp)da5(sh!^!x$>01xP1BiA}offb_8Hj*?Jxq%vBvDI4lj zjW{PpPvV8id{Z|n75dVf09ZDTj@uhFX?yzy-s&Pw16Izv>XBI9Y!!%3?YC#kWo%5` zkU;7Cgr>pGWRJ;}nMTrv!@Mrw5!)Vo`i^E9ZpK1pl4y`Iib~-wQf&G18tE)5v&D)* z>;ph~$CBVhoe$tdv(y&Ji5Prd{dyvMkgp)a7&F{BsBLHeBIw=w_W}lrAqbP>U?#WA zl)IzW#T^S)r4HY|=X(#U7&4PC3M-@s^+fXE(&`vG5Z#?Sx~*NimIX(>(F%z9c^Eku z)xRvg5ZsHM3BpHsuX|j|EcIZK3W0Kg(pt2qTwKeJ3{;Z!xve|U>=^WDwSX5cnnaOp z(lmw>K*ZXfnuE!O?jt@L*fmsY(4dZ2;j=x(f50AX(M{k{h@R;vM6R2csOhlQ8=(1z8n*P_7P-l z<=acYpt0uE5MW@x_S!-pZTjyt7}H2}cMWA~ojOPN)z`0IWyD{KBDm(ive5tNI1v)x z&JFtq{~VcoMa&`NW%qE`$03z}R>)}-Whb-ex|vKF*b)c;xeY{ye0kMk+3?E!mIVkR zC!@-AzWhoS+=_uKk2?&$?p|)Lfv2awJV+po#>q1e$;j}YRr#yvqH)a1+LN6Blr|f| z#_=&vEY-yipSW%ZoY1IoJAVDxly5W}WZwkV_(77nF)#+);tMA&lB#JA{wI zRML(NC>T-Faq%}YZ~x=lhaEip>lk?uD$)S_g>j=H>l?USC^k5p8$2r$Hekr(p*j== zCBFJias8CV6W*B`9-c%_e@=z07Lf5@QS4FQNLx>MjGCKxx^QmgJ8|(~@=5l&iinf- z0KBqWGJ=5SmP2ON_5$h|LFWXrqdPlBmErc|OJ9-6l00m9x=Gc1?#8?iaZCmTPgwRZ zsjvw}Ue=)@vQs3LDqDp7GRlumMJT!dNJ6)ZfZ_Iqwmny66pGC?9I(UuZaDnxZYRY= zU`8-K$u2_RQ8J*2LeXx>N=lgO0F9Z-T&g`jznErr28j^GEux^YfPBm=Z2}VSin3Y8 zkx((%yB`?-{Tf-T2vx1Ki_2Kdnq+MCzq(&9+4(vxH0I&@55TQVGVS6wSicY(4N(wM zoE0)7BdO@yjDb=uJ+YIVI>73pm&879DvNj8kVOPO;Wsa!z78p9E)6aAaf)i~^;Z!2 zE$?XXVvvzX*q%QzL%Mby+qSd9@<5~7je)w#k+yLZ5Z4wxKD~ZiI@sPZv~dlej`Fo(il@#9&VVph@(-#eBk4w zVPPBjlx?!Nv<+h%DvWa>kwDr2lD#M;bv?fIV~|lm@%#g)I3dE)Aw7^yTm7t)+B3=K z6(6t7U-UtZ!jchLp^MU9Hcn^u6$TI2RbWt%#I6)tt zae74b?%gef#iW^*TUr{lV4B_$&0n#IP5vS@4d0gZ%ih(O7>Kx4EdKZ5SNH@ueGJ+IfA7e6ygz67&|CiUsLBxD-5c@cuC~9 zWJFXgh`YX^Y+>}qZ#g=A+O5R)SNz{7(+J;keaJRtIR@1I^3a{vtT9AJ1H!qHQ(Z`y zcSER0c0T@PR+7)Wf~@843Ea1YQfG$H06}G2sQJ@$_rA1U3O5gr?$Ea2BxBI3;cedl zy{I@glmEAD+H_OPWLNiFM=#FRpZs}Wa+|}?rdz&CE)y~OP-yRkOPAgi!2`9oOgem; zIwKl_<^nvqEbEoJMV>TKEFDDJ^;g{z@ZMOeNvSzx_!DS)FezGo)M|Rwh{VCe{?u0O zDty+m?ohn%x=^2;~- z{-CI^po@OYI-u3RYuyW007WceZBs18L{@6*W_E}VB8t4hQxI9X?()fZaD`Hnk1w)D zBw~=TPNFjn8Mas1R#jTHC8WG1%oB?aPzGk>XJ_$hZlP|?l+oLoV_nd_3~|gwXh+a8 zY5xg=b~bE6kn9G?9LzT0Hh13NeT77HAX}j<50UsYnWxBMJUi;)g+G4K6TI9vYlHz>OyVPN`C=oikt8 zD0HO9ev-u%CnGAaRs}A1sy~JYd?oRb4T)VM8I3bEp z2e3!}R}h!D@Cna6=l4rbKW25u!ez^{u@yXf?wmd0 z*oJDwfp+Y*&#aQ)xT~wm-L!xXTexIN7IH{SrX6OJp2ev!@M*7TKGYyF9n>10`=$8& zAC=I~)Wi#LP3TR9rzYQ@4|nPEW#w-Z!;F8;P4)|7I+~+*5Q~aWi+fJ^niMkH;p5y1 zH9p<#{7zY2+ADf+zD!E#KU_LHAmBx$b9=6(JnZvJ^7m5x7%zD0Q@IPq80ExZnS2H& zba=p-4ZrSxoUJ;m?W zeD}nohg-9TC)wGGIu&`kEgw4I#SVdN*8lmn#0GWKze7BbK%0TV#pU)#c&$PhiJ^PQ?vd`ysvT;Po3xX$(UFVs*M;+u0_W!b~&4f*ME}6?Wb?;i~dn_eh_t$~Uk1-h^ z9F|vZh#F;Jr2OcIa`(-}?DuU>g@1#9=rGT?Ri!3&oj1Tl8nf3GkiPAj6^%>;hJq$@ ze?Ua`12;R$yQXP!(mw9j!qgOQxfw1k#ZyNE1A=ojr?F?Stj)NLMqOhBpU2$Ciz6Uy z!>v&Dh|OEJ+`xViJXG`J@!D3@bQNNR#_0?93~pg?d-|q#-HR8EvD0m3+6Y9Y0_+gG ztVIz!?O@!n0CgGRQf>rFZY?OjylW ztM9QlkyBtxGXO<619J*2oJ`an8Q#wEVlQCd`e^z%QvE;Lw=nHNiL_yRouWG%rVt+} zB7E4z#ehx``3M6)@tOi4*hqQVc1IRb0MXimrB81$<5 z+tGbAG?c+HzBtEjrU{uq%_x3NBSu_#J?$SnuWszY8-^Erl=_zvZHi0p6Zi0aCzrN0 z^7;Dx`-$NlTb{GMrK~s_ccpRH;@;!_Uk+G~Dr@Tg534Mk;|{?pYP($wU*YQ&mY2Ma z!Hu^YA3l6ocYB;}Mjo3iL9JFKzdriI=6>spdcJc^N^*P_F#NKB}cXyY>wFqPFKZC)anxc#TcIJ{Rz$ ziUJ>eVNnAc2-@LGQ;nJgF586KTTlUNu;R_wtqfx5_HGch>EFnA`Ety;W6|v69b5So(Z~Sf7or_UQ%S81;n$FuEV7$&Vn| zj)&7mKiLdROHIv_ucujrsWO*`Pl6~wvJQ{Da&ceye)^lg>ri*9Z)7}qQ31??pn6pK zjJoYgj0FLL!;qy)+r}R|W;AGe_RnoH6nE;B`vQ^svOT_;Lf7H%Y9tHeHdZ3;~&SP2WGJ2dMFp1!qW<|O-n`Se`LBaZa0zA^qkpB!Q8Ra}v z`de4?^A^wmO2!5P;)pdvzkhD|%fqdwv5ppuQz7n5rHg|b^_@&G-+d&HpC{@k*bR_J zfP#9EjMq6l!Vb#>VyD|BZ=eOX`TI{01S;9n9dF_2z{4b8J4VL_>gbFIgIxf?Ehqxc z1_yN*_Ds}tCE;(DkeGvXC}c#W&kL79a+m!8rtn6 zO5+ioqzbQ~qOuux8yhrnup}qXCIy&dvW+L+#1!f#-RG2G-Zle`)Qb?A%Ybts+;>H z%3@AKgSb-=APlA|0f@1V#0v-&Wy4nf9kF2$j}l~G_5!yHYA8RBGof+6#9{)I(uY3% zoQ7FI+%hq1n{8=nvdL~#2n+}fU+X2#up>jTKiKpBtAWuLZ(R$6Rh`Oit|c!~#EiZQ zT-sGv_hRZ0-DG|{V&R)$bcJ}F@KvySA3rcwL+i*ejCNn8H)5tiyn&Wq_AKh5z?er( z1}fM|F%W?X`;TRhJC&}iK}5dZb6Pbf9fgAAu3k-u7%7T%^^t2;0HRXhn2yqDeHmK( z1!@isssX1p;Oa^lb5qCSjc9n2`(wh(wX6ouwj~SW+4mL)4vbm+;d&Q<-CTr5lZn>8 z!JgF>b@)57fm9L%vVuZBg_I70Y0ALE4JVu{VGR^skOk4_M{W)nZ~nS_aqsKA=Q?+2xHP#V6V4YT}6E)$llh3V9Gdf4Z$f=}0I`59Bu z;sUkG?rr^bKLT3MG8v?;{VL)ECq-6`F`naIRU8o+$!!0%f%~F#%7T~r+;3sN_kx+x zsHqw1aW|MAvb60Ko3U@bgZST^?|h-n8$FBMNgcR_hkTF)yh05r>$EAS-s|W$nKk^?E;h4UGq_W`6ZdX z=Ml?TNJwa?uoQ46ofx6g7QGkd#ufWD=UqOR>+b8>^C~GsoT zTcs9(oHK8GfhX{mB}*_4Cgsg`5@MC9A9 zN4wMr0b%|6^$S5H#Pz24vM{?3c%uS^b??(BDs`AH5>ZtNV^`Onw{h3d){bv)`n9Oe zj(Z(m-FiixDjGbTc}`XIOnlzyf^70m#W7J91JYadsjaI}h~AsFncXYO41^SCQMsr? z%ZMQZ!kfc`R(m&);;<3B_07Ozfzjm2DWsP@ERdGoh`MJ&f_2A^9V6b2 zN1GrN9`cMcJhc3?YQ!8dNDwh76M6lcFjT)ChDgRv2p>C_tffxW2EKvt8uO~^UQ&If z-=n{M9MRABLe@i7M1{v|CdXCN@j1Y#MBXJc}v6Q zoyKScvRa}8xVM-IF@z`jSCpbD?BU}|q#hl+$-J&k#`PR2hd>5}&L*92VZI3R!iztO zUZT~gF3M*=?)Pgzni+jm@T`%_UmK4^`4Gp0GY@yO2zshHa%4x5`BIoOhGFv`J8K&T zTZHJyK63VGa%V*g4AVpuN}J8Ni#Nb6^KmkC2XVaZSzohXQ?X`UaY-G>ERFc`>dl)i zWW~3?)|B55Zz#(?3BS0`34&9NHiTtDQ3>l9p}+}s8}(ZkOjEF}sQ&Rj?la}&l4lpD z_B?fTQ2*da2RJ@&KAE`BbnNKgs;5!m4z1*S2e%Y+{r=WUXb6F8p`?yM#Yn`FxyD~X zqs_A)-pu%U7e689`_)D#?j5w$_M46?A_wUQrTZZDQ_RWW7%#cC;lRH0V2gP8sZ^7D z*tW$QsFa&G>mj-_Upz^PrFo<}x60zivg8K|ynpzs%Dwz+R3^e=M$L}zhT)kE5+3Cq zs~OIds6#0shYD;43M4GPPSht?scz+y1KV!lzlK5ozE&5PN73EOA!eDi5+7k!*_W`U z?co9oPsSy;IINPM3Uz*UL56!y9m{a`zZUq5MwUfL&;Q8d zdY^{gM0a(0dqLdDdkV{-uE4Xa-j(nF^?h#s{hmKX10|my+51)u2a)>kM0yXmq#6mP4O$u5zlL)f2O6T z+R%nwTv4Ko{8<_y#1uWAEIz!i^03;7)%`?L!KoDs>=!ROj;^h!C!k1a=~i?Rg^YdB zs@-Te-*Y2R<7CJ8{Qq)QSf$+gcx2=;77;V{;Gd@BiH_(7O^Y+RwUM_MePgu!{M%*7 z`STFj00_!il#Y`pQ~f4*`InE1sIuh|Ru`DxW{g0trRZVMY>;OM(lYu-M+JU&`leHs z|Ey`o1p||PnI-a*WDtPU13D?!lmUylf=|u?Ps*uN(h)3S?Sqp|BLo(E+WnqIQnNfRc`s@ylkMdJEIJrXr?ap1Mdn4smsK~US(XG&8B3y4`y zYUaHFVHPiGK<{Y3Y|W<1K6NQP7;<`~mUWf7mX_^4p(F5e9G#}zIx0Qp{|su^nT;9@ zcNV=OguXbOSFMY`YC#=Chq)K{7#f)rF`{OXfpEH%P5f;}*DDxKz>(Bqsc#es*+w7{ z_sZk#or%%6&@jzNu%CQB%a;0A?7^IY3#Kny+0SIGofHFP$ELu|+Wj5zEtEG6e|RG^ zbIJ$9o~WZhMGONqX3Aq_Xy!GER+c0t*SzTlL5dwzLU z6T!b_z=Gcm7(^x6NGdF#DcK2@ zHwY&&iGqcx+ZK#*@+Ua&aInIW zIT#R(RJtl>_IQG>-J}kout-Dy>kK${KEQ_LPlrQ1ZzKf>H&6;JksM7Q{cAIGSTbHX zeBHF8Hpvoe#N!LznVeV$moC}%lAhiW)lUkE?2&aRImBr3M0i0qe{_10k&&;FULeDl z53RqA!g7z_J(2WrKgr#dOk%ZUHp{Ul#*WV{S%DU{4XyGxzJnNoqh}Fxq2^<}Tzwkx zV|?p@ScH;8nS5;k5g*uYI2ePBL;ZJlRBAnu`k|8%%RUZH+qL&hg4X8o84-Bcs@X0r zo4ca>flS#r!`Zljcd~?ZpcwFmjbLwqO4VSzXV7;)jJ0fygl7onQ6z}GuUDwPUWcq3 z^LR?NSkN*@N!*6@X~6R6nHTOq;#o1c-({#1a*%Bt;n%!;r{rkomQ^4i z$&6(6&$njr(qy!Q4xUU`#P54uHJ87HP@#{>vDDs=28$~r-_VCe$41vbQ&X@r!uquT zDzFXN#Y2`!r|)rQhJu317QAbY*>||ssF7wpT5X@ERgK{i6O)aDjC&1STIr%OcHNsyWg%*-Mx3OH=iu@LW9a9%1InP(S1q|&~{!ArjB)?oJg|!!N=>U z*Ze>q?XaqO{pleEchKU(;o90JABG1XJor|nJt~);UVr>AU*Lvvon`V>6tHMTy_m$7 zXCUgoqK-qwlLlgN z>tO*MqU=YYOuw}#m|8+A2$7|$IfK0RVX|~PZJ0>MY@PBbDw!lwNA3CG&;Xq)oW=!w z_uWPjV#`lgy=V7sFYcLWl5DSQqYC1Vhuz99dIXQYT( z898FGpzhcf=I?GxVS-P}HjKq&Z3Gp7Kp(w~4=UhELvz}(m3Z8^n^uFJ=qemz4N;)R z_**n?-aMUyEA>a9%8g59U?&Xux$w+CM~Kain22J8JHP3u#e2be!JlrItdNWWwmFVF z`dHxv3@wUQ|FS(zo{0ooFOxpO#|`I96U9`YCyWywOVhHF5*3b z?r{KVQ$%6`1Ln;sb#*MFs6ZViJ1}}01&ra4kxzZNb1_-r!7un-3s{;W-iZSa6oJDW z#L&!Wf4{z~jYHs8(5q}mFC6%At=_Zy`6VG2>fS!Id^LeZW7X~2hwBiRS3Hb9p$%s4 zz}U2QulYNIbx4yl`cLiI>x1vf8e;ftHr1Z>yPZr64etQVw z5%HqQn2=2~1{@$SsfuKe(q2p?Ir5Tf1&}7U9Ut2NCTUe$W*EeH$hD{HD4r za~!J?7&G-}9eam3W?l@9js2{1_7M9OGlD0My=2^?KGR_o3D(ao%$|+Aw1L%}x()8| z`&m>L`PtQi9sz(MsqOeb3!CcZZYZErd89E6Y)nkSMUvI(d7XsyS=XvxH<4>?wU9Aj z!9Mt=H+f#Sl&11#1TUbKJ(=X@$Wa&n(VfpkJIc-yc~({x zD0gH0>T5OyG7yP41?Fe|$Tt(C8HxL+0y|IbK8eb^vZGsnCjuOI@AT6qo!~+c`%Z~i z@g7w)N+BDJWE(YYYVOgv8dQK_FfnN%Khnpzmzc;BW&0+-y~ai$aPEg6ACzCc%4cp! z7U6xIjHquGX-1?Uy#7a7Zx?q0l=Q*<=3Z#Zv9u_Qj{M+Sw?Wk}DTVSzXnBkYr+pNp z`Zun+{pOsTRFZ_II*rS4E;(;5NFv2T#DyfHcQqZj`F`>Jg-A5LG!0o#9t+cG$hV`c ztasb#IDt}z3(egxp7{}>y+7(**rZ)F2@dueBq;1AN+5v61F7xOa20avyvi)GEp zlbR~Z4ahf*5m(CYE8D_o!h~dgWR3^p2K}h6K1yLNs&TK_Yv0NCx32HcrFjsjQ8q^M z;#+TWr+<5(n_zmRG?k+ikF(9^p2sUk&=U_pU?Z!OUav_cmWuxqsAXhs`#64}wEO%r zv7TvBzmC+ve4Sfgfd_VHmh8S`@5Y8Jng_v^wiJLOQ?+ zZU#74SVpJ0HYsRWzy7vo8jOxn9U3m3>u}USZ>!X=Ejx7B!TZ?m|AhA=otGdw>wniQ za(M|rWJA3-{KSMGXv>D|TzS9O=x24sYnS1Y#&~|m8*|Y`|BJ@VwV5s>Zn6`n`qkFX z&iGil{NDe(*jTIan4+q#I07|QVk3YQM0ppMcCa556FQRu&GuD){tFk zV>hK)=O6n07!|y@zf7x>B&#xsmhkJ*?PGZq9R#pM#w|-PZM7C5LZwN(Nw_Dxm^5}< z++?V-x8MMF6#S&4h7yJXb7UtFzpBo6Df^k0)ocI@p&>;3Fm~$yZ_#np82Pk z3j)QS9dGNS`_hcR_nDPMuW8eVcP{0jRCAbTPSlj$9LVI^ukb{-binQ*yVd~jiVc&$ zS>IyzVRvTo=&`%Uyott#0o`)X zPxIlIAB#zTY4)RNDY{)IyKyj2kKTQtN7A&2_2M5O8A~m7 zgyE81L{iAZN7|l|X*d|k3PgLqt`=rYjT3N_MZqv}ZF&BICBkO-O zrf^a*{M0spVodffysi;8Yr)Mef+bv1Qs;8-Bz*k4WXF1C_*}Wo+Gn1jr2r~#4gcf4 zb?esI)@1h{y?VXoH&~P%Qr^dyfQobv|;8)fz&r zfS;7u(w2C|?wkYNZ3x6HRoLfYCUc)n2o^pSVZYPMBLwt9DpDa)Q^za=Fqi6v*5PFR zQAZd;2RfFiB3nISza@gtO9Wf6;Zccy0sDH6=UyrV$d;?R@?96Dx8IgUQ1_7^9zOlZ z08Sr5V&voPMe=ren&LqcPP?U%^O`||LK*{;=(F;wq*;M2nE6!c(DSPv?a)MPIF9_g zqe`kc3i_LQdF`hk*2nf@7+b|Q!#G8Gc5)cnbn2(Y0hyGTGcU|qxb{x;`0cG;5B*VJ zot)-EQ-R)k+eX_iCp&q<7b`47$9ToA>Bty9rO6>Y$*k3pz;m2$_5-EAn)sS)OhA<->sG~zg>LSa zwDZI~K^no}AL_QV_*VNi!+jZ|t=%kb;o+_u>N*B%^J#W12K;Q&kJn8%LHXm-1M8)O zPp3>N`I>ne)LqsJan4f5gmA-|H4(dlvPW)aN-?fG5A*z*LdDjjxT71In8AbQU1LUy=5}sv z%yd^T(Ra`@bdgX<93SWGeCBawF?M)bfj^%k@+2P3a#!*J=%dr5!fj3)nsHtmcvc*A z5OC|~M9kE(QQ?(IKwMpz5r?cL=lm)ogiDg2;J z>dzP;Q?e`G&2u2%?8B-Iw_Q#j#Idk#>{!+ks)xSoV1(u)&H>+$UY6Ps#Y9@~plLIs!4hoe1XJ>gMTlU8AY!YW^7_uVSL@KBG^p!Q zUYC>BlXDR*GE}2PmBKxIIK0>FxpPs1sv>KLtd@0yvJP9I6k3$WWrx0hc=kO!d)8um z`^^wkvSLJ79I=JtVL3K!>SP)_Tbpnp!Wg1;j_ZK)a&4|ox~G#t9RF?d^Fv=2FGJ~{ zGW`VlVgQ9Jwo@c=zGG8B6+u1uwNgf(?_&Aw+IfSpF_g)YQ6A4*(PTKlc{kbm!2f+%OV#>BRmSfYf&l6q~hxYAz8b1t4rf& z?s0g1)!kvQw=HA{t&fVbxDkT;nB2)#*f%f(6b(x#r^ph041@fra$*LlA%Vyh!T|)% zy0A_JN>2~2RuZnNl^+2YS%PN`{#u2|*IvwSrBHKDBCr?(!HfC*1mqu^%&z>;4a-2Fu zwu+J`s`>?2S`U~R6ELUZB!QJm^Ge>ty+#p8ih6Q<8!+V?=oKD${{VaLtJqXK}oM4Ap;91{c$HLrpRZ5fC9n}Ty@AaIv*an*@;<0 zb#)?~Oir^^uAQb^bab?EHlp!v*)r_Hu1dpc0}h;=X)5S#Zj7W8&*%s24Wfx28#4P+ zm-+f{W*(Npm%>2Jk7-C4T^?cVcic?Hg!wqR=qbzQnK)GVd547<~z-|y?6&Ci)Z z(274Wdv%7#JL(FJI_b|xm2O=x`#T%^7=*cupS)UNw(wn*Mjzf57aw1J#BtxV^EsVv z?}9nHbm>x>^(s)c8f9Il;U^*@A~Ih7fWK9E0Vm1|RkdB=nVEMfK^@oC33Kf6?^vCQ zD@#CFK{#!Ie}s;s{n|=Vbd*G^;XkP4JpEL_W4l|asR}RM4Z)EQ|E-ytAFF=Vf)Adw z_N$SVHSu2Z099M7`PG^^C(G2IX3uSrvpp$!x!#Cy5Uqx$(dJEiFazG&G+KpJS9kZL z4*3Wt0rih~UN=#-!omPV*mENnx}DqZ9| zXhzpPlf)B|+T{F!6@utfygr<`i^_UV$ov(=7x!${x8>!v_s#a8(^943laW)A5@2Ix zHg6nAGPRiDoH=TI;fC}*NOg_+-`hP8tmEOE^^cBIKfZ1>1`>%@!{EaRL4VQS$_A0y z$)1<0la@dalMQ`9ii9TMD4XR^fAI1$*HIQn$>=Dt)p6=e@g`Kb6kuE4ESj3VZVoPK1$&W9$+emVWmN12kb ztGkKIywm$3BhRjkD65HxU97R89mT?)y?c*+&OhC>pJwUA{*Cx#*;OS)oGEEW_*sT= z{koH?Z-NN8yiRS*wo!D_GvD7%Di3g7)3@*ZZxjGVc4Gi=6o;Pa`SeYEIhnZz&xlPA zSLXa1Gpt0c4D2IbG+yo1Ia5Eplf&)(jEU zTVJb(CX^@hcy?m$7kxA9c1u;Q&INuRSMq4$5371R4;_Az;gE5y<RN@@G2nc@p)V{^2Xih04=Lm45nd>tdeztFdalXRS} zX`>lM?BT|+IJ>3L zvu2fkZN4t?SieP1PEMz*u9x;ZapCe6hPc4?dRl8O4G2AlWn!#r%$M>{Yn1mw6pBq1 zud#@@+46nwjO*c%(VYV_6RsH!8amYCN@C!ODXuX)z;p^! zV3*P>(2>RI9xLEZ@$=I8d@hh4E*>SNS*OZ7+kISnYR^{8&W z%f_b*09cXmzvBEhP;<^Yr4Q>M`e*^a)>b@)p(Y)PdEMTyUrQXb_cU*yS9ki8PcCC~Bf1{fPXt&1I*9y!3559c zx4MHM0srO0VpI*%JAC=R19pUcN0%!~|dByC`5aJC}h#91IZ7*9ueP;2* zy8^miODb^+w0HVimJ3*$I+)+qW#_jtV&&nV@JNCciQ)iH)iX?^s3Vf4RPVR@W#@Ty z+kj><{}lA}tqTWK>?M#Fr+_e2s(0NN5D*Y&eYWPq%;sZ^kE{jl#gV%cS>#DsNg6}I z!u5%eow7(mlx*F_bZXA0c1lJdMD9vnEEU=Y91~7$;Pi4AZ*BofHyDvEbFQ1T7C!yt ztmdLA!{Qd(fYDL4Kb#nYm1l+Q;28QYwGqurivLbvHN+@MRJh~V-oy;U!-gwS%2G1ZXi*)=k8d~J>PJVTvY7pB z)-URE+^8>0AH01xuI}%QYY&FK^bB8hI4f!5#P{KSk{2CmHSgss`xB=^k4^C~3`UYH zG@3LoveS!c$oNh;uc zRdj7^7(4ZKw%&&MRX=wg?NV@xauqe7k;BtnA6&vbt^7;mKq3 zSNr+;n*OZKFxJU`9}l-GPK6Mjj}9-&Xc^i$x6>*77)0krLhVAhDUZDf^Pz|tyP75C zaz=PdZb(7k7a?J#CQ*m|MkN1wM0WX+!6#_+870{6=(qj%-S*DM|D59lcGJ_dw2Mgy zDf_5wL>n!rDScT?R#3w+B_;Kb>J2XM`FyZxgt;=M?&dc51yfH~g-3-*{xYp08Pp z&CM?l1@UnDsp(PnW>Jr*(*YY#qR~0O$0*;U*`iWTGTapfJ#3sEty?3^V+p)M;X~)C zK|^e!xk%eyeAet++T$<7?&Peg;Fln_+`>IfG`QI%D*wQN$w*g_+0Q?{1S3`x5&A_R34#WL4TeK&|UcM-J#OASbt&g z$Cn+52o^Pw4Rx}xzJyQ{dh*cHjXzr;W&qr7QGe`*p8G1-nAKc47)%&l_fStg5X`E_ z>h-w}^&dlsN!Av);(xLC=21EJZQJlsnj{n=&4VFD2}x8MMWR`Tlq6Itl}a>;5K@Lz zNGVjBh)Oh}fuey14T_Q|Ns~0c+wQvF`+nB<_qV?F{qb1qxu0w3Iy=wf{0;lQO*^iQ zGy02O68?zawFgCsfai>*8gf#+&sD=_V7t_9LRGErW80!>-zK zyKmWIj^{yiE`Y+9u$w^D=l7*c=^4r&JTP9XCTU$oN38ipg(;H zY!z8#lQGJXVH#VvIG(`Phn{oJh((XG?PS!OB{DMi4I7yeHb0d8+}<<5rKKipteEt9a9^*;r-al;ky-zAU37Q=988;!VMn7XjgR0Q)iTI zB>xX?X%8_{s@U@T@dIU))-w>kxp(zc`W%swc4ItLm-Nnc|7f4fCQ4I@^4JP;sG)4IgG4{ho1kv^{f{c^c} z{jN0!b`SvzR0+e5jpD2A*^3}HB}FqDKo@9>f*YkxZ09Q4zB~ftS*8qp*{|nI(pDvc z0xk`gpNSJ&#j8)dZFO{M#_O%atm`^9e**-SsCod5e06}#ib#Fon!CbbY)!0 zG;DTW!@;D$Wr)#_czU0}ra+OK@+4UEHW|3wCxqdqBP|GFKy+Jl1Iio66D|z*0C02y zB3fMNW~H#^okg8;Uu4ooZE;^Lk2tXaaMjt%E0gRON3F|@+Xf0z_b=nWn7Mw;A7^_j z1mbVqj>CcnmqEeHms;KrNnedphcZvWAIe;{9EI-z^G5^o^YcslP$$zW1TVjq{clSy54#T|NHY5|U#Q_z@vhK=xvSO2|c+p4g1 zr9o&Ns^1AmVVVnb#FvHKPAE69ZNH2N*}+*w>^aILdwEgV=R*MPh83X2><0LC8WrIE zkDd){k01(S3f`wCN-QE$Gs3rF=hq1CBPE~zlS9xi`i=3dB{lzu= zkBI5ShK>%At(Hd|dk#v4>6xv_lkSFl$YD3>&I7~?>gwiUnhA9KMnXo(hn@+8J-&Mj zrRSZ8-LWZ$2-XDbUfuZ^@mS9ROjjYBDdQMHxMGU#%f#t0W6r>Kg!rA zt1_qcYt?zmW`C>z%&rcpaOR=<*k0l54PDn$!y&e(Eegt~bJhJ32vgNjVZczd&?9{` z0)`t!k7|Orp_imnI4ov2?jn3Yr@@&e9Fc|_Muad;Gy(UITsPToUp74c6scB#1Jp`y z5*v(J=(^b3{lX$Qu+RX%y+n@0%`TBw_YGEIC=HIy)^%I%;+oU~`SG@H(|~g!-^GEn zUihIirq+ZZ?JW7>&}K+qX(wTOhyV{biLMC<;VELHus~2w*_{Y~Pfg_ZCbVi`q-Y2Y z-g0D*=vutAF3B^|E-O~UdnIsq)qM5L*Z@}5coSmeI8im6j zKryN49z}6XK6UsUO3+7C-Q7G7EjB@&VEHgJ>2$4V3oz-qlYU4zp{^KL37HJ6ynA)> zQeVd>AfSMJI5d+Qb`~;F!wa}Jc-!kTs4)pYMY<|K3Q1BAgAnmgj?uRy2@MJT$zuuw z26BrDq(^=u)XxwV@o#L1854lK8d{^jQWU_nC_@5qtdS)JcQ1^23#?R2Pb4?*Yk0?? z)Uiryc7cX>4bloh8um_8=6<zqrhV*0cnR>@Moc<530&Y zeh8uz{zr8{4&`x+Nqa{O4wn5+6V1lP#yA{vl&eOTW0?G6aDI`;3h?O}6oa7KOuTD5 zyfp?CqGx+0qEtufkP9XuzS_6qVM*>~SJ4u*!B~pE@1%rJBwuY(wu(JS`Yt?=*(l&l zZ_Ip*UqDSh^fXg}Zg~&4hl?dWVbYBq-hyH!ex4QqE=W?uLVQ5OvGDGX2zr?bxUFH` z5p3t^_WbK()Q7m**l@*c+pIGPahWwnpaFqbfMWhensQ@b5*Q`?cO0Zy#HXVPToH!SKjS$0QGqIqo&NvdL46u!x90W+vMdl#ItdKVvN0|-+ zkVy0I(uf^I`YIeBFMu3Z!QxIsC~z?&^fIS*1{dU$0ho$Jq$sfIJ*d+6iXQ)?n`dmY@xB;I@zwQe2}oFKsodZF7AKlN^;RWTlSnGHZU z!axxEljw2-lL=fF9NY>pJ`kXE?WGN+pLB7V^({TGEorn_D*i3tIZC$&g$lP93u-(P z&iRb2Lh6hHq!SrOir@|@!O5y{Sq9uCcOX)-rO$z4eqEIXD*F{ zQ9&#K`ahb)fyGW~hGF5)5wfigx`$}(Hk9`S*JEZB{WkDT2_ozAiM=K@Ix^1^c05`H z5b6H5l>08#gQVdBicU-fHn@zVhIp5)FAnYg-9z1QU(#$IG^A8BqsnM4>j|mJtlO~Y ziZz*tmE9IW{B{!O9_6p23)3<*1CmFHCe?*yo=;1Yr=eCDS!dFSEXMpIR1uWG(}Xz$INE6Werju*0&#P2{(5Y$ z5}Xqcy@y@Fd>=h}G{Nybkhmk=4g5InU*h4?Tn@TOu)9e|j#qeS{gNZ<>qPAaT(Y>} z{>q&X)vti7MgWl%BtD`&7pFTDqJcdCXpjae2b9Me>N_-nD32IC>6fVN%*uDcK(iEM zlR(6@J}L$J4Jr`9g>2~)zOb5d@=@aX zY(ODKj$W&9iw($skoYvdfgphyJH~%p<%UVjda|Lk=h$tTtGjln@Jw0mp0V@8rn{FkUG3V{ly2yTS2axTzX}o}{3qQV8L@zdmtuzL#70s|YGxkLFDkQ$B13!L8RroJ^k+K;dtr5*!3kso zJ`6@LY(qAR90J>SFeeK_f}-;6OrentZs5xAk-N}HP^2za-0ZFmIO?$q7jw)IPPn|b zrl5eq%bbDuh;lX%JfN}Y3N9y!ijEL49F*0I72O6=MKQ#9)ST-_EdP7_krfBl`G-zl zaew0d#{mXnka8o%<=2%+v0NPfa5lamdJIAntY~6pAt)5PYP>q1{4Ee6 zpa8HCz8mj6{!0r#@Fse_81$Aea6`lF!iKYkvS86jm`vlwfHF$YC4=ijBr4L&0o~ht zZ7JD_%vP*J8AF&NitgMG^B-ehkdq5X%fNG%5=hruamG{P0>maP08OFXx93??Tv7t& zQXFx<`V~=87;5zAS$OnQXMimaAZ{ZmJfaL}z!dcn6zPjO9<|Z@iDgFm+KejP9S-52mtT z?{MB={U|G|1fA~%h&Ty;28`>Ul+?AD?=Wy;X>_feP*%a_MAmR{1EioZ&6G{>Tr+If zbMPHlb3_q*f_sg=Zc6__VTKDud)Y?^53=y#)MOIwnwSZKrofISqXLyLv?xKBfLx;O z6ZN%84)2JZEZDy1DyE?G!&Ug{hI3xk2cws$g0NV0U_T+&$KB}%(YFX-d5&|*1-di{ z)u?=ckXj224%+UAi51Vm-c;f|%s3W<+YT2%QB@_l9!g^dtDNehwgH=mE?cN^ z3h`ioItNhb2q4_*2(_EvS~g0K&D-y}D0VU#r$#|d&EOTGL7hn_f_kw)u}?m0ib?1V zC{tg=L{WmqF74EaKrEk$mqaRM+U+bPucu=Tu} zsn%SQYBN({IRn9u;1uIY9W+Zt*uf8Q)XWrot~KBs^M=bvz+@&b1!>Vbhh$Y9%;dSJ zYbIszMs)G8)|WclK{bUIRPo+ZES2&{Cc&=NpYWv_B8&hC^W39nIs;Z9_Znp-S+mww zL=7jNJ%4@*nY)A&hZ!h?JBFLO=U_gQR3^$j0|Q zzo0CF?}-gj6;UAszz}nYx2yw~7EBjd9+PlhQ(HSa21s9q(FJ_uHAlYj z0+(`&%DS$m3W+|Xa1)=}D7IXETf=+oeeYkr+&e*;gE)jry0Aj-L&U9Hsh3?gnQUWX zE2=sMw=9s9965zUo*Rc5Rn(9@oCdj0!;vY~M&;;Xqj`{me6w;yjRbk34Anq5KT!63 z0&RNO)s+Gw``nY8jH||l2r&R?A=9~+7oS)Lj%v^mq|Srq{yduRpv3))Va|dlD^^aD zUlN};?0bk>O*ZFr9-hK%@%2im9A)aZ*V#QhtLDzLW)>gc_e4}R^f^I%oy56I1wM4Y zt->>)OdRMXl_zRCIxrDtFay;rR)vr29)J$RO_&j8?p;8~6%^dgM zgit7Z>~FeV;l2oj3u1+HPX)BuFS^n1{qLVk&1I*0bGNo%Dd^IOz7H)ZC@bn;bar1g z*oH7LBvCAwwJC4tBjP9U*C@iwezCuo#FYwV89LrTkMbPvSBhiuu%>Dw(S(fckPpSm z&7dn{eA+{&t=ifndvJve+)MWUI@AK-@ch$nrs~x7faR4^QOQX-Z+wTeLlg;toD!1C zAv6Fq13OL>rUf39%KX)jGq6x|?%_C{594D9UnFNI-xfG=kXCGw&+63t(4r z#Q7X-)9`v!A()>O`=B=P#;7(Sof^vN&$#SmMj>?ht9w;F?ol;orvb#D^m|&~^70v- zSSj7f%3IZu1PER*oRfotl|j49hRQT!5Cx%FhK1=-dJ4dWZ{zi%t2ZXTxC}1{@Mddt zwm{nO7O6ys{2hX7lE+he3>1{7A!a9k1o{|zm2JO!HX~OLpg#{FBnpIt3>gj5NeF|i zeN~tO@Ft|XxafLkG82A3a=4*5(u5%jWoGn3+Xuh3Qd^*vrCIb;GXp9;{pgX~^EnrB z*?(LA4TU)+Z~#+cVSvKZw=51|3sHLH>4C08M_-?Z#GF6p!7GR;DyHB7{}G<4Q=g+o ztwd9lGXENYekh~DUMv&fjJZmQX|2O_QY%&>S=2`4a5 zwKxHMl^j;+K$6&mSP(Q8ix;Cd5N>f64DyHykq^`B{pi_3`~%|;`Mx^uDb{fwuN}`k zmQi#a0DA#}2O+bn&a8#b7bzZu0tw8Y?^F2@7zP)ZUw~MW%a*11>rvG<(1IU13EhV} zfQ-h*T?;>O?aL@5Sc3E~s7ZB9^-%0rOhKVIunAsKI3_7w44fhdRA&UV0{;-qf)rpm zc02T0!C=YtZ>h}3T)Gp$^>^A=0QaWME%FN>C>g*JCyWMv+1H0tSXsg4yKdpTK0bDx z&?U4ml*t41vl80~$S`}sX8u*{zA~|#ewKaUIwTvJwm9TS)t#K_gacQbQC%Bb{0m;tc|XqpniktB`9>=pa|f`gLbWi-8u zQWZejk(UR4U|JJMzS-R63yJI>SiS$Yfe{e}U{%mOAQ|uAt&VDppg+4Lj=1L4LkG7G z`~3de4qO5(JWV6T!5NqSQ!ARLpHjpDhPO_`n}+aZFCN^^lVdbC2(3L*z<8j)2Bh;6 zVU<{&r^3U-YsVkoHSqynOnFM^(?E8z;WJSShEq!chX_(G5=D~mZ4;CSMHj$@*Ew=x zbRpPLh7jYArX698?#<5lpwi~4H~+LKcNHr)hK*EzUw`LgTrf^N08^|CqA+jn%tiz< zOF~_EMi@okgZkE{hRWXl9AIf{JeSQGqH;MvkZwss*p za4y3Im`D>Uo`yV}ocJ;FAd|)gmn9dl;n{#g79%8VOOq4UWdOo587^=KV2*DPqxA4X zi+6u8&P=n5-90_>4DZdTsElj7y(VzoCZ@pf9(L3lAkXuUA9t&CfbL^MOb_PC_4HVM zsBM~m=%PGi^*5Cd7cdD{dySM%)EKy&6m^&=0$oyDJ0HPNeD z+C~F?Jm(h;BKh3VX-s!UqWa|yA6c$|Uw~4$m{YKTb72{-_1Wflwa%*13;2=jwIL`N z?V~jsTefWS3hp1hGKrh3cVzLRMfOb_yeei0uG-k;qLwl0gAmBphh<^1i?K(ify$XT zl!VVoBPoNU+W~j+rP*S=vr}5{ix--}e!+-_*LIEJ2cFl|m|0k%7$F`#0F*wKuq-KD z#zYS`WC1#F&BJvw3ra*;wCY?88(?8;lr{FF2zliOxHNkR=h#idd|#jEl%fk=0HVA@ zJKW^rK@4)^k@2TF|L*a|6Ew@L>8QUko>y*%y%At4!;drpHgL~j?qVatNF%r+9;O!v zdE_)FCnwFISqlr@0X!SZmZ6%V-^$_-Pk|hwiA<21l3MC}{nj+@Kuw@BLFP3! zBfKy_9`YuFm~X>0NS3{#ieJzZ2K!x6uqnoa2FhpDc8vP;(vP_6iq~ERDXCtx z%$_)xKYfVApeHGL?jtkph7P`mKl3(@5?1Er$DKl3;PJ!6gTeW*3~#&dGSQp82hlAnN!y^^CzA(P3bpm^+?&Q2L*Pu>5KWKdvK=hGj{b9z305<6Jt z#3?}Cl*^A@0k!E15Q=8MbC34M(#zsjrAP*o>{`g&=Yb+*XI>D3nr`*`^c7Hoq?itK zh|TeHeUGVr*dy?D=c;*nc(DuTkN#Z8;!hWm#%;p{pVJ}yhnpVSoMF@fNM5UT|j^ z5u$nU!-LyT((d`&;ND5<L*q?_2Ic_d4e7J(7C`~Nh$8B1vZM<{TgjR!U z_UqngtmY{MDq zSB*h<;K4$FrkHA;7n=3x(QGIairU3y9X7q_Z{L`iQap|rhh#-VQ{K8M{(82SOE9E9 zc|Z)hGuw~xyocFhTukra>pnZvM{>Z? z8I-ZvmIr7Ww;#k<<~ve_w?y{my*b}EoPTW4KCHd;_+ofGtV(8|O8z7lTC*8v8&;cW zqHF5~-#gA$c2KxT@v^YVz8Yx@?d8>y_yv-&u?5OFb6#Q6RxD0?`!bXs(@?5#1&Eg? zg_@v^P%O`g_$R}Qg~2S!G?y0g2*K@-tUIVy?oWt`3$?541YZ9Bi7%>eFVQxwRGNhX zF@9`j6AaqWkg+o)0`t0d`@wuPiKw6L%Mh_tjJcfZ8BOGO!?^*g1>HYoP_}=Ay|it@BZ6Ba`xe|4@Q7F(8czp3@6gUfJuU|P8FRDNMNt zgQJs`!9hD=RADk_=r{rfAd5JIcm>_$LqZKf+)v4(3Nwwn>|B1Gy;> zg5w7r*lq(IF}X*LZF4M@m~SA1)>i4FlTn(hwo%r7K664S%b^d|!P^T+z8QDQhg z=E0lh>*Pmjzdsr^Z5&43sf`N0<{EAlF$N3uB#NJ;6i+v|DKVB6$a`3ihe!JV+DmT6 zmPt5bU$ytmlv$69N#$g$9K&aXi(64HeB%OigETLhvItympwd5RF@Y}jVIS(SEy0aM z^rO%z#uvng{6^4*-;72kgOCw|mYZR~u~j|ULZv5m4oC8;;n1$z-H%vz)Y-#Ji%{HK zwT|cX@IIj|x$6?u=RAgF51LPcx)U8JD}(Enc0iX4nk$9@x2=;}QZA3h&Fna#vjq|& zaF8++o)#9uxX`Y2DPSD7G)Ff6IjEH?Sa^xt$7(%^tH{!}#vCGH=8NEhWn6(i`IucQ z3SNC~eo_M(*7K9=moDXD;3GYa3fWFIdQquQl35Z*koIMV9d5?W9Tfitj70J04U8L6 z9~3pe^&aIttfpn=Uh7RYRF7lH1Vg<{mE#(ya*I2(G*mese>{o$nThe3+=6}(QlKIT zR>g^ zkEYDh)vL?hk5hf~E*UilNtGabAXEqxUI82tB)vf5$wNr@C?XU;FuXPu#m*148V_H* zcSX6ZtSn9jVh@SBLFLBxr4H_-j&Z2fZ_5i1XI|YO-GE0^g6`LVx?j^Tk8^oB(>Kgu znM^Y6e45tB{?JtBZ3b&6ECZ0$IZQwJMfZ1Udb;_YMKEC@+hzTiW71Zk9hs@(Y7QFj z?>4D#tA{Heu?$lfjE&Gj1#zpy0LIQIDp?LuH^7DCQwl)`lvq4SrAP2p+>QOqXa(wwtAzS}HS{1{$xjXBVcxG)& zG}hty)iLledP>93d?d2X#jB9q@!#Pmd*Co+{GkNh3+Db{&Hb=aHArz-86vT|Rfvj3 z!%LkVl=XX15>o~jLjz|1U=J7FYEoX3AB_uX=Q>Adq4_TT#G)mcJ}!P6(7T`_UqZlu zfiS_4IucMrv2o=k@3Xc>e~fP^uiF-~&ExSK2cNB_4uiNQQ5L9F^gATA%;qzyB0y!t z+?Zo*4!2IKR3ysXxOCND`K4)@;+zN)HcgMOdy|qrsUXA(OyaHE?(#hcEu+n&A1e&s z-JE@5GHkb3O4QAp;``p}-t5Z9r|8_yn zN|sT)ENf`kPs2_1)8wUZqgbzbV2Yu=%4$VahDjT2g5jlvFN7{?#6PLAd@ z2r67xOET8T{p35`cel1HcBv;Fi~IJmF;wtAX(-&!9%j({F3T;PqwW(t?INpQ$ zSCaG{A_0!#{P4z7@4<2RI8VJYGp$c8-)A(lc(6gbV!o(oYDS&2Q=OcWmruv~FZDN~ zpxuVqX$nR`!JfjRb(+U$+S<8;Zp^D=EEEp16uEjhu&^Ha;39nJJATXs^6h3)<^u3~ zy;A2svKbMma}r;LB(DkA?IO^r)Q<**Z8AH;5N}@^% z7BXMZZ%YBE49#=8bN8-Y(;~!$V)>iZTbk4GYCH}KSW^6!oa2cUzQl(-4FL<-7v>K= z>+?0cm|A8BkFQ`jpwIc}=Q){P-#2ojKG&~6Plh-V=;am5hc4$PD`P$@C-VT^q<8Q} zN*n+{h$Mp8-W^vXA`%AYhlS_l^^5iWc$)~_FO+&`U@J$CH@ST8-{F<(>`>ndh8``E z1ap?fjKLkXoK2&0?NcKPASq|RF@CguNzY`C@#@)h9PLNNC-}&2pC=^D54Bu>59f!u z+o&HO-`=v;$7AGiznh6+_n!coH3}CKDWNfv*}!NXcFaHlnC@{0Y%WAQve6_v)#I6Z zon@${C}x3^A+bK>o~>&E-zeCAMdDV%W#c|j9F&7zIOl@$768F`tEBv+(de$_P<~hg z$=^A&j0-0w0Pavd<_De?`KQSNhDzi}SM3&Cv35Uatk{W(iPk<;XEclrfc?8;YV8Yp z;XoPw z5~h`iabU%8K=M%od?>~kM_wp&Gk4=9@SJyk>L5W5X)th(k!RMPB?v_Wm697bZiHw} zDsS}XGjEX=)He7)>ZI z^;{W@=0G)g0Bvar>)eQfhYO-O#M1~vr@kZ=qg0FVbVG3CJ1mU@6o?*~lcB&4y~U?- zS_XBk`u=YsI4;Y-xnunSG87UC{r{QLvrnP(wgwi&_|IRt{}all)~ASY;z2jvQ(a!# zGX^b~;zbVj0l(k&#!1_jky{L^U7Rw7$YWVnzL|T@gc}k{siHcamX@a1UXB%m)GUlma5Xjt~_diNZE|;~FuFZeUQyOMZ$=IbR8qty}K7g1x$9ve= zHcFi9*We&|#3i^`fdHaWlSrjDv!klqvZH+K|6Pr7$?sdht6=4>?HPIqba5K0K3H8i zC;}ZCl0@4eFT^=ajC9E_wRrO2tRAZB>{(~eatkml{twV`T4dc&bhw7D9E+>g5x)tD zz+b}~JPlKsQz&-ZkIj&mmnZO&VFT<>z$4)N$J>js!MM@FIOano$GvFLBA`d;%@TG$ z%QVEqC}1mi&t!4!8pal1BD%gq8DEmT?>&^v^ZpD%teO0OQBg?fd0|150$uknOo;E+ zXx)BPEJOvc_=6}oh4OSNZh+Z@s9^_L<7~ojtgqkU*a9H-d~^ov>zZ)e2ck{oo;G)B z<;eVWr1D&0)T!94gcPe+1tIQ)cQER}2VcNFtN1hjK!qrE+IT6z1t?NaS7L(5ZO9-4 zP&;qW{BcXo=e|p}8vIp(P=B2ris&0r!~`v5qM2mZfgMt`oY&s6Dy`@IBB;)1r(FS- zKutXocMYN32{_se(=|5|6^{Ajp`bbMC+3I=uM)S)5HE_?m`%aP46pC_pn1yBA)}8y z(8A6s+XWo>`P;W!D20yUq6oudTLU#EyvGOYTJs>*<0HTc%>W_1&}Q!JNy5c>8&mKHiz;hIraaOvrPw0vb$TAh7mZhKPQ4@hU%Hgm~!6r|u46J@LQ> zq4$?TclXZ&%q`!7E)*=S_(ZPI>O;~fH;hw$jHn!U8+F3TZq+o2Ousp}42#^wx|CzS ziv|Y;@eQq9v0@ek!E$h!sDp;(3s=Ik*r=yBi%o+Uxgu8*H)D|MMxG^{ zgECtuZtVu(e;Bt~N$5BEhg*ehuU`;V6-2>yurZSuJMvA^ae`;-543uX z%&T|r(hJt7p%_HZyByPvs9~pa4p8xO_%qh?w4u=3acwn(e~w$^fmv4V$-plcqE`KHyUYV9PVgcdx+% zKshLcT0$KmMF1^BurDb-6LSlTIlZ6ee0R%tMuU`(iSh&npaLh8Jx+;}kZA*(zc*-s zyBP2CCFqX7zg0(vV*Uw_AayHZ%8k47pp&H2m~4MIWIls{#fl?J3ytb@h22SuAXCI? zT!arqI4*i~*h4-6SHN(ILIl1A!uDi6_NP~7G%_OM6MS6ce;_#m?jM?+OmiH_0D))u z{cvCnU}zrP${L_qXeGf5MK_}3oaDR;OGg>=)C>I`3%YDpF)R$TNn69T;wAbwT&UWp zy$ktH2M!CvQ&#{N4lM{$_M?{}g)g>NB^qT)7baU+#62>l04@Jkn{uNz+m%Wp5}$(- zh0l!=>6=*sWj`vv^D!5`Wi7xalxgtL>88(@m^}Am3w;358ts>w{|_zi(p9_cVG~H< z;kDZ|a6-HQ5a0qcCr!je&rEg!&;eln@g8SV`Hj^m_s(?|DzaG`uM}1TAym7L@ApGv zqn6|)aUhiX`F*#MCPZU!_Fzq-+VVId{AfAJk5Rg70E&Y9&ak>jVtS@?%SBvPdvVR< zq}MT;$63&EJpa9)0TkH}aQ;)FbF(c=9K-(RpfIKYVtn(#-t5k#hy2!G;4u?I1gAe9 z0cR6h*IzURV3aF;c(<=_SVF^2U;gHXo9}@~_f8m$ zY2E;)LmV=i7xeAX9vEeS*Rig5f>M^;G|+_H#<|b#>rJo-{qaFGoVIBJi3f0hgT{*k z!-IQ|rn&va8)^DQ-uUlENx6lWeuCF`iz>TvmOBKLFZ3Cf)kRd9&>!RHC~S>bx{%gM z)>fQv^q{~%QHmv6|K&(Rrmt~U(4BO~3!43vzX zJz~3O4-=_>&tl@YHKqq%QsER<(x?n|D{&C3y2pL>iBJ z*^J-ByO(;W-nTIc)71aRx68&LR27Mp2vX3&Qf1SS?v%7! z5MIVYbE|YLW{q6M11>BnA?u`l7w$wB2Jslj#H?h(Fo~PN(E$BJpy|#dZ9Mb;r}#Kl zGW^n|6S$JFV*qS#TyR-(dd!y%JLH12jSCmvHi+y~tJPj~#phRVyQ?yKW`x)B3JFaE z69oB)-!GRzTGW647|nn(?db0xk;U#fO0?D0{mV-lfG9#7bP;9w3$!ih>HfN}0AleQ z+Rml`B>toOdsl#gGdOijcKvK^jqO)JUw8&2Acem|c_sp32#v2I z?Ks^XI9aBp7SMB!&PJ<-UKCkK=T4|dir`_Oev3u33Hz!9%GUGAOaG(ewxhYdxRJVY zCt=`HfPsWA5LBnAG+7y#+qlt$ynw|EzO*QV%sMqZAtUSAY%Io!SD$Z)^2vMX%3c<$qS}IYE>A|H|9#u%k}^^G{gnhG{jZ zN?`H{R2lNaMB+<5cCtjHx#ANMiRw_`4v@25SURmoigui9IMPQ@2}(wl zbeZ6Xl1mlaVm(je11{g_~aNh6>-M{eFQE~*9+0xF|gM)CmS_7kmotRX6 z422jW#dIB*#PyxoZqE$3N7UxAvAAZd)cw}L-r?a<+j=l`iP`rn64G-v4(0_!H+SWj&@!5q*~F1V3mF@OUq1o(ncxSU2P~HJKF?Jt&e| zGtokGqTfD)stm0k8|c)EjHM_!;2_gPeiD2P(*f%vzlcJmUA6SrL-jnDH&$M5A(6TeIS4bQ44K0OLfACF(K@)Mo9q|x<` zhILEUKu_W5#+86NuwPtOsoKSt^(^O4Mc5Vo{?)$pSh;iM151)zqtW})Wy+*)Xy#yN zr@0D{NK%&d_7G#I7Q{-BNvZoNiAJ&_;lu_YN8%88z1okdLK{UYBFalcJMaR9|0Ph1 zeiOr-DBw!_AKA*c#E&VxH!TOcu@|LAps`b={6`C|O{+~P}2fAnYy1AezX+YxW&Fc=JMuY|X42{)e^ zDTlXpl;T;0?Xf9$-y&4KA-;ACceW|*zi3XNInA8z&ggMO641<32wR5Gs4C{VD@>SucpIuVb;&b%@5;Zb z6Uo>YhX*^&7~T%w*TJr$LI?-Sa7SK>Orv+|jL}=8F@y9z0(EX17ASq<-e52~C3k~< z(KkCk{B7ytaE@m5M9jd9(@RbUdu`gdk!BwJEq0N<0l*=CtNV3uWi~(zpd&*+&khAj zAZoD0C}*_0g!ExH8OE|*1kbkz@7CU}e1nb-+2#`vFAb!MgH|PSEiqixs9p|;80swz zoP34d6y<{%my#7L{1E74U-tw@1N}#o_Jjj14on>GBO1f3xL?i-jnV0uVl!~5&^|*$ zgb5Q3y3zCcOp1_W1qDALHq3zt3i{}fqS*t~_VTJoYa5kuqdkpl6PjiE`+NyG5rwih4j4pe14({WEga{r)ic7zAUt3FF>2o z@N(J@AmG-D$SNrCp_^N-q{PWk;0ehbQO;ikhser6vKNArIN^;1tl-o3*u5<2vUoZ`BB`wgwCxJvIXzxNw;EuDW6&Tql&u4VWa>3 z7O(({fu+*;P@3h0H4KZn7OD~Ylt;c*YIN8T!j6({JID}9$ON23-+BJ{Zy_v~A-?rG zpTGUAn)N*ErR(m$+l1FQB9n{n*nWev-eOm}MuugTY(AF199Fr{a6W6@uIOjne8Zk5 zNfSV!UVD)8@Mw|;$=DH0v}FG)!%nPSNK9kXlkqdhqSoOVGguWAJIcN#_DwhB%h19T&3G4Oo>r@50ffza^E?XP~3Dd0)3-*Dt4V zUUwe=ACL`CvP##-m&Dmhv2lD1=*gB9tn+%jbnFr<4xWLYIef#SX_Ai*Z2N`|}K|!;GeIT062K){UkShI+IBuCnKQ3I2j5@x{qCHZfik|z>elI8R2XQyJ zZEL=}eKAG45r(|Y!nhW=3Rbdd6yUTbEi|Gx)E%2QvRX4gZZ2BGtn0>i-q1-?epPz_kQK$RHYRw zw1Zk#53AVMa+L(!dvnhPUrmO;jfOGD7vR8&nOn1$u3nv4P@V-NJF!Io2Z{j^@^Er; zV)f8;2jp%mW0cU7Zb*}>a7hpYMZrFx`$%#~qhmnfaE2<%$3vY&VF1=|3NeApc#g8c&y>V6|dg=RrSurBI|}UYxta9PYx&C!)P=wNh=RKChy6B=P?WLw1FrV%5478D!ccd-nS?uBv7C zB=7Y|-Z^G*3SqJ@pFB7!iQODc8YnAkuS$3X( z4%b>-#;zVt=%DCC0QqF>4~wxumK)(zkcQEfx>m#ibvl9OlcV4%sQiSxyC_>IY_RBu z)O;jCkhUev22-%45IJibiC|dlI*5fOEW>9(+-JP;?Up^+N2K?TjSo|SDnVQ;Ckjz=i=e!ILAjB6a%k^p zGA9*t6eUepB?gaf@@&A4A*VU?31svEjTwXxEgAy;H+b?RE=i(EaFG(q?tl00LdXwj zSS%HbsDoLU3qndWfV=< zx&t#;?+?d=w{x_x?$en#{poL;$LEvElb%t6ehSZCO$TEIL{kLY*SU9=JjN_)`Yr0) zhEWGo1_7Ct01(!6Mf?-gFV5Aub>Sei&M zvV_vz;t4OB$3PSSNNV^SRRP?WYyer2V1GeIlO!4JD5dvIs*~{>C8ZB!-|@@9zn#KBe`s;AbHxcjX8xWBi+4(zvjbY=vr} zm(r5OK0o6ouS0#0xMv9D(-Y43Z8E-B0L(Ss$g(1H0M15eU)XttOFDACgo?Fe< zZeTqb?D~weuKQPUgUA_%>6*7it6MN^+FkvDVe7x>TwK$`;BckdiuJn@)p`nX8Mx@r zJ(cz{F8>^%nxSH}6;sJ@&tvXp+av*F>TPR%FjbO8J@t8KNd4(g|7_e1YxH;~hhnyV z^jlIletFW*sbq{Jaus_#&kb|k&J&xDal%U0`!%{N(mm6KbWV))40~$N|7NAt$fcWT5Am0em#fEs)UrJ-)&ztb{ST^dOv;Y zddF&qfm;^0f_(EQPK*~_W|_u>ep>S7$)DA1YWZpfKO?V9@N1OheEq0j?{ngxM1RHdEPqJ5_LN%Z76(QE}F6G`2wdgM2Gxx85_SZRA z?9Tb)%fXhkwO^EO9TXmqNbge7A6x>UbyBAcBb`d0tl*0*<*q)2Nep7Q)QS?4iO ze~jatkb$D8d|DM$lVYXZtGQu;+b4V0AMf;``=|uZ{A$^{uG?Fj;^bJqDN34n=@0*Y zq^>QzxAepJ$5xh7)+>MJYnSG!-}#weI4tt_Npkb*{$nK`&jQpH%@Y_?kg62&AoJOD z%U!PwSC5Nb$l1kVW#$?9PxZH}|E%^(8*$|=T`n1ScLi$4D%#pdWzCN;Sux*VyUQg$ ze3Q=Al#D+urBxyEQ6|6Y60^=uesO}*EjbP0PDeaFX*ML5`*p4bA~@bCKqT$oD_9Uy9yS?rar3On6niGx z%uy;Zg?Z85`{c0q`itwm{JVD_f9^WG{7!~T%^G8cw$drxoOPF-%p|x3YMQy5!?qM5 z_~n*dtnp<ew*33_;2Z;wQKr5Vo_uNZ}dCo5z`hJNTPLM4Zv zn>Tlbh36{c#-6Kg+zT#mJet47aC79^?mD5PHWC_6&)=KB(+@6lyV3CuIF~?1>RrWU zjw4kL4IKq4v)G2q4(@9{n(Se5GUdP}ueP9Ko3DzOSEgNhkR4W%V76@Ioy(fjIojuX zPjvoRwq0Iq;x*fzzNhol7VH_d;wtoc(=~)4e-8eZcOg_Eh*U76(V8rgtLAu zZDMvE=yN34xSxvq`kS4e5a0N5S}3P>Pc{{&9?w!td*&&$yVQKcq@eqrJ{_sx=l3vL zWT$ZQq09V--pp^3W!Bxlal7uMskG!?mNEGU;)82X{kO>36dj^W^dqiK@P4YXx%t z=zEB6aDUY4ugMI=7OhYy5OS_D-xPA-ow4-0TP*^w&oAF7=RMKYs+0U@oHys#s`A0M zhS|SfzxA@0<;ql?n=QWe{;buq)^~DKydJICZ?5c@S!}Z=GwD?Jh~whhA3a`-bRBP; z%Jnu$YuGhPMQmPu2!GO5tKNo%r9R33@D(ym0&6aB-tIjT@wPp!h9~Rkn`VzB*`xz( z2D7Yfrpk*o9lyR{o!k;n@8p(*cP-wh!gim3C0e&AXt}Xdw0#O!uwE#S`Ka*m3x#Iq zgIk1H#Uu)Lc}O1%-j=R65WdsyGE3q^?t#{oWj?HzN5_)9x#wD5y%n$E-lk#mAwgP* zO+!Lp%C&_Kt8Hxf1A}|F`_5VyZykD*f1T&9nYk;AAMfpO%DGf_A}%+~V@c9bg9uyM zL;k`|W-l7J15#8gk~6=CUAfY~xTLv%v6AR6$KWU_nfCt1X*XZ~&3{SwW<R+~==X*y`JrX0Ckp^oVdJQI##%nlh_re^>foe(uY7S*odw;rg0sT{$UPCq9~# zi{#3wJW5zIr|oXEl(@-4nR6Ap^7Yv6$gW^*dGk4p>(%MQBGtE!zAW9YZKPeArtrN* znqNBIch=RSDHqDjMO!%37c1Wit5Hqii18n{4qaB}eVIjiLCNDcLy0vV_UA2%Zr~9# zJah5jeC~RP-NmjXZDZzh_wcC!B4$Phdwd5y+bcqTmgZWd-P9j#J3G+-N+9X;^^Lo- zIh{OPW0uU1xO=1JchjnQHqm>0N_`F;I(YX*Bk$Ox$aytC;lPO@sqAAb)$%`nsgQ50 zab9OP_d|rdP0{_8?colucHDI7UJ(=jLjA6{@AQ@bE*WoHGSXdLA+o*il)m(+Zh4Wd zJM(bsgve4gA0DOX1u@BK>rbDl(a9IIhKy!gV@jtE-%IbFMnwY~<$NajThnZL^ZR>U z^i{QU@@#XzWOk?SO1pm{`O`&C%Nk$D7iK0s3%LJ-t$=3=?^SxYR6kwj#VxfA7@<6!e?#eH*fQQUGe;v+tg$C zuYBw%6|~%@wdU>Z^Od`EXEKK!Od6LSV)I{j;-`P{u}|Dg-I&?-#YVaDs-9Dx7RFcp zSR3WwDC<7d=F<2!eK4s$BKOhc%J8J4CX<3wy2GA%n8`RDeLh7fNqXVQ*x!d=-dJ56 zHY}X&d6cig&sopOCaJ-_s#D-~W#zWe;Z;&M%bk*pp2_m7HZ2ufdD}0-|HD@?O+VS% zKii#5Z4G1w!2^BZ;TDw<%~=}voCGE^-Bqki$_H#vQs|J{HzhXzaBirwFZ@TQ*2a{WwlLX&F0JfTPJvb2G6K7F;e(EEY!U$tf_ayGv@4RmyM^+ z1?ydV{`^_i8&7GjgBIyLGxqtatPJv*;Uh6e?=H6%$6z7vz$&?mc`nntx5;g|Zd`Dy zPqesX`t9=az9qa&(us;fF4fClR?gXBP*dA@L#@wP&zbF0=Ci?2Ia{*}y$%(=%a&f7 zw`?bydCw-!i=IAVM;jX~&b6C7*mHn&Z&%2aTXq-9+}>pxM-T8!)yfMzuuPTt-mz=U zMex>`v-0Eco_4*W8(%ARp6#r?*R6Wq*4Ka~!S=qw6@`Nxo5a6I>YVNPt*5Hu)pH=J zgUkQR7OPz~sofVFE*IChF!VtFl6g@tKW0cC(j@^|3^x|fIG)b8TF8vpT*w^JTn zn6#4aTzq_n(}+w8=cMwJ`O#NyjJaxDiVKJwTUBxQ=TDpZF6ZY#oQCajdh2ve*{^fF zI^FX7^d2iQH@WfA?#0n*(M5-M$VJpnKAqp+1q-YR3;MUNH4*YH9bC zbNw+M`Tb{6nbot7Hi;HasT(%-VlzIzQ%!IDKEKk0CC2#ro6oPfb*UhZ)jV@TD*KeL z;)@6EojOS@mPQ8hh4)kxrwMHtAL+5PFgSW)isvnJz4d%e%H3WsoyB}Ws;835{>g~4A`5FNXoG%h|9by z;hV*Kc5|fcT6UK4`+<9RbeN}{j`z7=_o8dp>g}q!K1sC!U;mgTS2V3Km-KP_$M zxBY`&+EX1Ei$w&1y**{XRe`z&jvZwre z+y|5L>7JH`vDH^Nk2*f*9%%I8W@)+bz#&yVd-W5cq@%HpO)Ggu6gd*}#q?%YnCg8? z{aDgHp<_FAykxU_>#yPcA@em=+GjS#DEl?q`@idMU3uZ!lkdB#H)M46T|UOM?)a3a z=RMzOZZ$jTb>35#8@oY?s|Bz8UyuxzXxZrU&jP*H}J1 z`Nz7SiM-*|$stjB^9gtEhQ-+}Z}=lML&Xqi+A+x^l*G1iLdE%yj7##DeglquGo7-wM~Jh_x+<*DA6&G`G)?i8=_Rx?du^c}yR-)8hu`;lTa)RvAz4`U(D=of#aOzGH$}muD@^QkpKbWue2WZDLBl+}%o=!q@I_ zelPQ?rRm|e@rZ~Kp6{?YkPu{(t)30KqXx&b;h0OnV%3fteMzTs|mc2)ckQI`h>`mEwkL&qD=lA;E ze*avz>vp@&zdoPytb^nIevRjNJnxTMip*`Z=h9PYLgsI@v@*XnQC^e^)&&~r*b$7zdDRQa=W^7}}!6Rl}^F>PzXwZChwodN>?qZAOciO0 zcU;7+eBY5^wdrvpF*+gcZct$Zwv738f@}2K#X{yoA-ZOXiH)nKqaUo&{LMNdC2{q| z8ozl__S~%A#AM3t9?p(a5Y%w~TdK-(PsH0K(64cBq$1JCU31LrMkmAKi1WFIe%yvW zo2Fg|yzOUr!}))I#gg=Nm-rZ)uV|{KNm|U4^`(4x)i)c>-t<=W{!pnm#tnkb1Xu3) zMe*|wErqc2mDwlEdgg9l^Oh5vjUAUyjh1@X9yyk*l{;GK6ngcsIzwN@3b9Sex4K#C z)}WDjkup*Yo{l%&CjI-0pH>+|HBYZx^9mS($qEAn&O$X)llrQn)y(3Vq9HR$wpOMvuK+r zbyY@uTeqOb%&(#B%u4=pNcU{rOOfD27cP}j8~BZ7yCX`Q70+8gPAp=~RVhjPLuDy> zSMjM&p?-Ewy8p^gwca#>fSk}o-f@ODHuiN#nVCa^UZmgdO;8UXrr{Sdc!2FYSwtOr z?Yi27C%fr;1Lz^%Kruq_AC9}lRM5(=Tbw@1Mi4M^@stN{l&U*b_`LIlU6tylhC8DX z*Yn@$y50X|Oe*!1^cZofdFP3C+1}v9SXGv3(F`4`BZ=SQ49&1TPbl3(EFS2q{-{5b z8Iod@Kar!m%u|ak^i!mlYipRy*>s-c@7$7>zEtwX^iT0Ty@_rqv08fzrlniw!`9oo z?H|<%?c!B&u~wSKwxw`j$K9zq&aplIywmQ(bpHbh>mSxE8t%WGeTQ_K775HXA802_ zIMt@v8mHh=wQwF6+HSG>p2wz5_U!WHz1la!)x5T($pKacmKAEG?2M+>HwPgv7$#{l?&mNP0%{$t;~RUJh*pmLAGMM&Y@vokJioulljjxpnM^|~Ylm(|Tf zUlS0O_-q8q(Xg4>5BJ<8P0;-E-tz3Z&Vj{a4}@gntTJz51qq%H;WYKW3?;eR{;M(e zNrm*f^?EC@(mBoDlGy$yxI@#zxtMzX-z`x)iu&T$G0ld^q-Jl+;k9BIiF-F1!-f#Q)??Qe( zC2|~}QNkj`Lh+*$-8*T?Vg>Ax>ZZ(`uA`Hmmt|r;QWCT~G*3m8ruTZJ(qS(P4ZnM# zIPVpjZ8@WH<( zS2*GLf(U+2-hz$&5SQzsXs6codopKKq?iLjv$pIrQWuV?`+Yf-b@4))sa$yI$``nw z6Xc?UiuH!W`kihEhEtqo2W+<%E$+uwaSFw5n8M3lcRc02ucFssBV=}FfOE6eeYBx$=SA&$MiXe0z`%3Ftwq{NzK=o-Z9wxPB+CqpD7_ap|~*Ly`(k*vyh^uzO|52hqyTwKsjN`pyeF zoxVR_W?v<6HQ(|2w5nU1)uMH&>+b|#Ww%|F#F9$P({QHW+P@ZLj=6>^mU~2;Da%f! zsCiy$IB9WP^hJzHGcS#Ua^;zRr$A!KMvfG{r#d$)a4!|G-SJIV&`j+XS9mVTb!}mM z`?d*nKzPtPMka@DTef+s;ayUP%<4ADA9{&bJyy4R=8rQ>EMpH{n*Mc?D!Jr5>=GR@ zthQdFK9)Xh)7iK{>%8C7boEARi5t&n^x``6wZ%0_g`#{MB^NStRRUXM{;moF*n+Gk$h~b(gE6QFi;A8hnrH1ce!Vt6-rUKRgClLYu z^4!#$_3dDoDL}3WYjYK(d()O&qNBXR;dYrFw`Fa?S1n?SwOXFpC0kiB70R3dE*I9t4cHcZDHlfY#_oB8$0D8C#<6EbNPZW9ZPcN1=FP+S(sF+}y7!&xiVw$f-QZR0s*U@m{ zh=^8n$D>D<-rnt@nV&-#gp7?#C-m&njUKSz;fsb`-8B%gZltjI;y^~}my%r)CnmaW zOJpj~Ri5rb7@h9RYxyO|S1jn9vrckW$1joNLCWNM@{XT>=q0!-T+pwQ!o`(H@7E}u zOqX@wd}I6coMqU~3CDH|i#J8=p;z-436#?9$gGGb3@Nw@GuLn_-U&aCR6Ox8(TAut zQX-R})<3deCA0snoZn!-_iC%Q^%yqD)I^4j3^T<#>ps5B`>S_8bjMq@ZlWYu=;dy{ zvXN6$+=A+pWsdo$Hu4@f3}-oCte<+&kZxb5RKk+I)j@rCjW2<~!Mlw|yLfBMR3|){ zB-}>Zbo(uZ9N8q-w)vNmnQimGgVM_{kKeB{ihRJD_|5Q8veM;?S@;%S{um`T=56~U zZwf+FzxQZL;(H8!dJ~rDozda`x+9%Vk&7XpS(7ut+dn5Z{nw+4=N}gyIPMl5i3*KT z>DlD|Rr2ROouHPO5wlQU@(|;?&1;DdDmcCGYhwECew*1wt9Ex18K8r`0MY$*?8Wv|Iw-sI)EcBmmqPU)-a(GA*<4bhXIlbjc9 zF!gx`p-0)D-Yq83`&sx~FLV5Ga^Be}y@E!p-}yT7cWNZdK_Z zCyBJwg(>?Bcw#$TRgNsojSs$@oTqw|AA3@SoE-t2zqh||kZ9Qv_Gs&0k#Sw+J;m%- zR-$<*q)sd z#v1%t(3}WdGSVOcZ*{Mf^I+tvQ(_RW>wA3c7EYrDf5| zW5L`54HFum0YZ`=$W~&Md!eaJiPN}I+ z*wQvq&6z!I=D^ylzU|n(4Y- zjQ541=tK?I48hCVkv54lH?p)6lOR&6?>*$+VyT%y@zcz)bzO{JzqiF7o4TmF+o7H_ zJykchu$Wd?G@SbU(_j9_Vw|qJ*9nGBJGEQd1=|N0N?+O@2o_$L%NS`~)F&gqRjnhl zR({>6QSRvyQP^WrJO&%f?b}U`^$$}Y*muaylvgWF>KSeH` zp`;~@c*5wm0n zFlj%*tKD(EkHvXfszfvNb;Jj`p-riHRUe8DLtJ^N%Sg@i`A3|3vl#xD{x|`%laulX zJD!Zb;P;1a`l|96O)hKmG8`8&yxe&&@eJKW3kO|z%G)1AQqnJEl|_F`o=hvyB41Nq zKgHpoZ@(icl}AZ0E8itR@oY#AmnVPudgo5>+JnejK8nspwnfyq{J&B}a*qewC9B${ zgnq5n=<43+F_q9w3ZNCr(BBSSq^O@r-nLdXC{g3V9Gl}dpO1+6;Q3^RqQ{h#KB(GJ zE1AthcEF73=bcObdR0O94Y``$nhkAW-06vbB=E8)7k@HS!%8Z8B;!)yj62nn4b}ik_2Gcb6U?t4 zvxFv8ygIQkzVhJ<-RD5;=~AJB@7txe8ZCW?AL%7lb)4_#-d*K)u=s6IpnH~EXdroX z)?TiLGMnv>B&K8}Ut1-!G66f*VJV_3mG;@1;<(#`(D#RZyxu>(=j-HI7~gZ`hh`xr zt~gFba<{i~uR{4qz|@t7_vjfu?%Rba_^TPji*~AR41|CBOQise=mwjqOGBU#X4yfH z)ydSvC*ZPWm!?}vhivX(h|(H7kdS=?%$;C66t7(?2S(cuy2YiX=RoC|qH1kvv4&|u zmQV%IlVC7lcjXU)9Rrw-umiyd4E5c!NCw%f`g6}w7h`E{gry0;>u>}!*XKAJ3$DTF zdIPm)aQZeC?l%bPkdLFT^D;Tea{@qaeCEBY5rlaa@R`{tYy(9POnPJPZMo`|m%Y;Q z)%1+7+d7}Eb15{+Ohi6cjdR>bFhj(UeXHKGUR6G#mlxHT`0l*8SgsJl1JkV7akaoe($|BrJE~QaJR%qw^17#?qIfLXSygSo}NS7cUBA+ z6I|MNY`Q4&Px5NW16!630`*gJlZ?jIt7oGy%K)33tn}-$>CUZ z*kyC|y+x@DvrOm9*;12^(+Zxr{q#~kV%HRvBpHrFpm>>-EYnJpTkI|?cNo_*m9INk z-Tq4(PVd7=KOCmurc5Oa%)Q{>c>BV{LkY&dPg+q!iPl>?w0;eVKcmWNN&j{vrT6=L zjY^?_?#18w!`P!$7vxumKSxc!-}Gv$Y}}0w){)G5yX=bCIc0ld_(Cs zL(&xEjAgxtDCug3lCjUC$R(N*9T~(QGtI1tzgL{Uk>zYdf+>rNYboU$>B2iy_Kb<{ z&x1rcE{y6C7Pd3Oee>CU1s%7~D&w0)R1|pGhw9g!qg6ORnUx=RI68T>v8dhNOpX_F zL6HwSlT(6Xit!&q4QN?otDU(wcZpBTW?XvaUVinqgp?7syEANl{`FST)|!fO9=m$y zlYHX7D9&n8>1 zZy2Z8jT;hvjmC~I7SGq?UNXA-`^$-r4ABMtFl^qVUO%QG@ljn;Ec^AFZQR%s@^n4< zcH+M)PLjUAR``ocJuYJE;R_v4mG!NsbmMy+K>4dq54e_|CP;B`%ezY5ePX}h%^tL{ z{Jyg%l`vM7VN3n1RY$}zf@4%}ijwPiAXvVFAQOP!c?W3(jO+&W-bz%_1Nki)bVgH7 zAVt@Vif=}I4Mf%JG&p~4L z>t+F$FLxUf7Va+}8Pt*N8-_oHtp_j@ziyV2ii0p4Mts<`H_7D4K(_ZxA#}&(6O3da z?=0kW_v&nVcpwOukn^VoaHP<*285qy!v7LAm-?TEPG|FPkt8zot9ToZ^B6Z7!$d3DZqdWRj2{_d zbn}&#GBvdyX=|rHpQPaSB+H+?_VxUf{dO1itee)Lk&Qd`M(w1WM8Tbs`yGuoveO%X z)X)AHT9)8XBl$4h?Qec0kW;9RkfoYg=h8QdN$DPsg3N)2xsA2u9#cP?ytO4|+1_F< zk@H1@4Fx~b_3XzllD`TH!bp`aO;kWy|D7P9obImO+K-D@9r7Ou{4njJK4O}TWna3R zu0ssB85iGOY0-1`#;3Y0eJb}m1BaE9%UtQ0mWXg#*-f$*78{gqUf-3^rwYgG4Rpo9 zEQzdzvbP(3y3?s3yf8pQ_fhKz-y@B`{8?HN?Tgl1<#!%ymY%3p5P!c(TEwSrW-8^1 z31cM0ggC^#CVrI?g}LyLxK6%mHjY@NOe|SwX0a6`SO!7?jIA4 zH$DhhK-u}nnCs&X-%lrEJg`F|)K~G9`q%`5^Qm*Qc*PPtoG)$I{>oI~e7wGNl!h>8 zKJJxpp=hF8RsMTdHbJ3`n+)63AvASATa~u?iR!{9BP3k-Sjk8k`+q7=bd~6+5~td* z7!QgbVb?mwT=BbM3 ztIt$tkzKWNJer`c6f6!UW)*%d+MtK@?KeYRxfRukE*&M3KF*W)@cLt&!ebbo!i+** z>#f{iF$%+;k)>Zat=dn^!q_PvaZcqICT6D3)%_x_{L#@|l5inurB&$aRcEg%y}N(x zdGYfbCHi+iv(c@kD>!uH)90VL{LF8-ymf@*{ULKkV{Bo*8op@UlM)@*o4>{S-sPEu z9ldHi)o+mPpNwZds7fFu^t7}8%p>>svt9h%M7~CXzHXO~b6CDzb&CCLkr<~}BoU&( zys!82TBzO_8Ze_4(JL}L@2QPnayMd%@!r_`h zm;L1({6`kL=ys^u(pX(Y9C4rW-M#KLt~^`#KJ9RTCIjVf1^gUa-_Xi4;kFUl0>(Bb zT6{4N&$9eId@w@R4QHwMbFs-?~8>G3s&eg<fzzXEj1yF4$C}C@Q|ivrOfp*@1DxOJ(NF`(#z|Uru^o{ zi`Y=pOz^#sz0WO|eCbPJdw<|zmBd+`vOC0b^KKS10wv#gg^&T;Z1Exq8ks^~%Ru@N zI9ifMcjweRz@sjol0$~2Ku-VDlIVB;dON7((Ey`hjXvRdd}&BVC?Q)hYF2m*01?FF z10;cgMQsvKV+$;75PWWJ)vEywAN76`z#qIFO$lnYo6&A zKH=JzmttK@T)s7QL_NzcbR5pkDuILaf)#574fj_D>&A^}~V}8gB>WY8~koy9<+i1!|+L#`Av- zf*Z?Ek7K!5&|$hnlB`$!(9&NgQQz{Rddl3x#b1BEIG@GbQ{a>0kK8%aLpkkX_m^wN zft{y$M1c2v8r?1NVk-*FpJj)azI|M$A0IR8!Yh$6i*e3gQrukMl*&BMIX&%X{^aMh z8GBafqfZI4qiSAp)C|8AFzxXwY5>N3fcqDt1!b zR+5Imr}pIg0-wSh)9oA`z7y!~%+Rgz4<+glugrE4b~}M}Tz4IBxMf#BsYF6%1LbRS zmQ8)~`l7S=I1+4Fce^nHL9HV%&PX=Jc~@@5VHhY@Yu;WT|ItON`@;35>`gWd1B77n zSV<~|r_*6v*i#+%GfpP2eDreeTHI`0%3rAOJo#yt1AlbeFH4MIRP2LLnO?|Ilmy>VH@P033QLaZ%dyDdRyy-535ANLe+wLDvKOpS?JZq%>JK`jM za%ffwqtUIgjbhf;uWhHte<`c^7Nt8E)tP7)Fg@uj_NtdtT}$?*j=$CCNjT6OzV&J# z+QrgK^=QS9E1aAK9ezwZpKTk=2p7yoQ~t1?r23tL8)_E%CbSIen_kDR`;^p)gX|Ix z&!+uGQPJR!c<&~e%Cw5jro8)_ZO%mj*#d?sxBP^JwtYXIR?~k_kCpYsMoybcvpH$H z?7XxXU{S+2GZA1Leex(+P9W_eNBGZLf&djMzTx45t~`Zze*MM4bKbJnoHf`vwGVbm z8HYDl=oVX<9Pei~c*}`rHPrRI@|wKrbS?17X~ua&A=(*Z)$4s6gPx~CF;?lLk)jfA zh1whX&t%zKHmVec3Ij6r$Nf_~HzWkJw5V2U-#vTduhq5ej-!1nGnAyN9$>3d^&LlE zvGYDiPHDfm7rb^1ttdE~tZY5YlM2wdGY0?X!pn&wsGi5bP(N6GO29m;Csg`6mw_|? zf>dhjI1FLbz?^ifDjS-^L*s9dN|rXPq==9Kv%&w0fq=EZ}mWzu`&g5`2E;UXvZi-<%uTRkeFH1gefY}UI8qqxeI z?M!SIGpa{qIKoA5S{8Pw5fVAMa1e7wvNN#Z>5%fcM&Tv=nQK)BrgwfWVO z?oXTi*A_k%zD^9U%_euM&CksbFR971eE|Deb?2$VhakH@x~}(z+mC*VSFGrf4CS#; z92GxZx|-re%(`k=P3_AxL)W<@5!Om%75D8&S|Zzi6}gI#2*0?L8rL>e>+1Fcsimj% zGRirGJ)8WV-R#SZEbn&p)=4GTFfb=`33jwah8Uy)$R8bwSQoJ)4KRp?=m32#n-5S< z^7U)W>?@l%7@sBt+!yev$za^HtZW>gp*}28(c4foQ2(wHTPPwWA$T>6m7WgFY&HRx zpouvF0JmaJUJqyAhR4e@YxO-tci5~SKVKYm zVqi;q?i=T0r^sB;HdlGZUMlUfV#W#O{p`3X{p^??2ghAP!pbvGL)J1vD3aOm148)U z%V#>pk<66RKwwhleImXdfu06_ffw z(Z$x|G7i*Qa(hJuGJ({77^cXwi~68GiP}nl@#n8Z#ih*7&IE|hF_M({g z%iFO^1Mn{^!7l`nLyw&iJa!Pkq67`R0RVhh!&no>2cSiOg~b7|*={nu@c5}y&tb*~ z3Mc}K0SG1vy#c5BDoW&74lb_Bz@JpjjNY8>ze9eK_7SUA?_!vy88f_<$BWnW1Hvfr z*v;*5Sh2V@FC$~%fAU&CJ^}M4B}p{}so@Jw0rlVF7PgHqv9i1J$&xjX_GItqF>Qzq ze6qiT|F#<^McAH}P(y;nJYCDT#8^mV~mA<k%+{C zT(xvoN{=p4qq~Q7e{K}r$F_7IrvGK`8%Q$cj%^!0G)lWwWS5}n_%^bX?lmaIMlN)U zen{R@8N4~7uoOp^+W5ZQv#b2-qS8=pRqhm(g^EPA?RRc=f*FA-8E5zoe)K z{}oR#KGlcR8bLV0Sa~MsVKMM6lVJW8nP`J8Hv^c!AT&5Y%6M!iRAK5~Y~<4;G)ZRO z!G+A8P%#ST%8}1JLaqSmH?AU37yRV6;LWW;ad-}BoieHIs7Acjzq>O&XYhsSyeMP^ z8{_Yvt0jwgpVegB69+%csXzL;dxnXjRbaV5>;lgd0esVd>*{%5bI!CFKdwJ__MHa# z&{$9zZ>9*78_~+rmEDxs2*-t>_ECBIo~qCFyQ1gpvM;nfteW7H$1;hYU3(PzzTW38 zxBjT7tF_<^z+^yP_U;LV4w{oCLvE5!d!JH92t% z=4niY8|NevEQ0*__3keOV5tM#6)v^^`R>V34$TullnClILUzxetYhaliOAXR|JBgE z(-;G?1lahu!#Wy$p*}ZmvI&<1FydRMvL`z-wbP8Jw@CLrT5~@4|Cm(R-aboM68_=rYJ;V&dy|huID1Lk5cF*QZG?l#L6oLf-|px6 zS%6taoyjKAkGx9Y&H6AT*l)e#Yh5k!qkmj8ulx3`{c@AXkWsKxR-ImTVLnPsW-@}0cr(}n()`d^Qbw3LaBD&kMjb#z;+%rBgq zBy!yQxG@?L@4I?DSz1XJehLL2&KYN`-EkZ@J;8qch0dkqPpUeh>lO;h!u1OG(PzXJ z2lp1S*1Fty{5>%uUsa7QO85j*2lh97)!Xl6xQ(fP2dC;Ub1Lu5br}+`lq6wf$#}LS zBj>lx!|iQ=lnV@Go_o-_kTX6ZoI-vKnwBi)!rsx_|NXtD>-284l(R6ipty+d;Wk5g zyXm@akH{h9wZ{}EUWLlES5Eosrk=7m3c#@tO8A@Xdp~iYw-BGA_zP8bG_HP~0|qnE zbG;&qPZ3G>>|A^6YtQZh*y4)I&_|eFR} zR9t!F3&fO5$MFD~0nHyH(?|$shXsG_Xy@F2VQ@I=FrFu_OaC!%#}}7|4b+`Z&22ge zZ+$&Y;yydMzo0;?0Clk;`|D~4ib~mx0hp33hY%RC(MLu`EUdsJ%N>deTH+87MK6{Y zVWeZ-9LeZ0$*?}Jv-5m%JUdR}yZ3Pmr%*JcrHmJ}^TPJu2mbs37c4~8-~@dZ_!!_O z{e4AJtH9y~V5i1+26TMtMh?ZtT^pxHPo0U41}&Uxe5uwWqx+Dzp8~*-dMDvXfvEi< z?~LyTN7il%m9W&M?fv&99UyKrgUPxv>DlX`q>II-X2am_mja<1MM3S!MsG|9ebnw=n4@fMc`=!?1b5B@DQUx28bBY zL4AHG^=sz~z=TE|!1WPTOI%XozIV@q2z-GeSttT6dx!F)DoSr#tMKh z?w#e);v)6kr}zP>Q3)5QztrK|KsVt+oJuf5FF$g&Ua~-yIs6-#kD#~kogD1ehkyqz z4_29yYXi?{rp+On)&NVdc}K8?f>y-#HgLRm zh`n`~D%uw#?z9#}9(zA@%2qZ+t~g+XnVFeUMC1+qP7tt$My42p@S<)V24r2-t?{YY z-w7Sv{a5wwa3Z>ooLIG3HO@AL(PXklud9;M17R?1=;olPuP&&u3J%T^FgznalNEns|m042z z1V>=UlBj7*!jWikd)$02;&klMs9WCPK1h`Su6g$g(lax;5qb)C*?_-LUzYtX z;*J+Rn}^^}ht}~nGwetIWv>p5v0vq)vNcE-zJXWH6v-&%hW@V zIDmiA5%@Qr41`FnV`D$^Q%{uHi-E9d;xRgKbWKSKJ3) zn#ZZBeL=NH0CBX3LgF{bh>%-3gdXwG&BC}rCS4BHF!-bj0DB;wYv(Xxqd={_0wCfh zHa9XZ14*f(;;drfY)2-&%htTSMgLb?V0Pv7*8Dt*f8M&-c~cPC3xHH<2uzK@fP@bB z(4po$)7wsS<%cOKS7G_Wn>tpz4Pj#g(@g^Oot(DjYlL04gz=;=!NTCRU$_sJ6kQW8 z$O8r-Mg~waLIoixucB9vz#+Y z-{fFo8W)8o!q>H-qc`45(gJ2I$59F~Q$U8|-w*JzczB|S<^;^zx#_}w6oIi|(;^z? zH4KuR_@8UFR0x41Zcahhylp&C8!hvVl#-_q?}TV&YhB z_WL}>~mVYd5Tl( zRKLq+dgo`a%m-2G$;;ql;SFi;Ae1v;JIetrB`6|-5WFy9uw7d35q&A7vlt9ek0Zwo zf$fKB8MppXoS~Io4Wubv03tc94SFLkyzePi51`1QQXfTU9hh@?0ZnHJ2r5dc?dbFZ zZoDtBiM|JvBbGUpl#4hCoj4(UE~;50qM9G~S0D7t0%ZZyqCHq(+Dd?!Q!&!dFqi`+ z!huQ+(OVF~*sLwd9hTAwaM^NKNq!Dxa8=b)sD}gfL43H|s=*`HpVJobKNd8Na6R!j zZ#>3af;13Z#r%NJ90zYPrxAEGpbP3O4Gc)J#3ND~5C($~WG)-Zu6_mFGLF;i2_cYe zSPhih0}##>FvJwy!2bPimPG|k-}&Fbg5d!^72e^)24Dq=Y2h9LKAb+R7K+jDK$St2 zB`^l?K%6dkoWFAnkM=oCasphJ1ojS}wiJNA;F1sq*($V;y}>wU5Qr)$-abxC>kmK# zn#A#S_rGl)E%t}vW$kZ53=a7c0D39_&qVX+C4iS=;On3+7Ptf7Uq1O~-9Pg0O(tnK z?BoV#Hk=Sib4`8@$GSe_q;hW=I87;hGm!279PmJ$Ti^5}od^-_G&3RY#_DV!>1w&vNPRHYo-%tkJCpTwL@%HYlJ$!y?n znDl>@04G=fZj@{Ry$rz+MW0L&Vj}iAiO&E-8SlEgV`THtX{FN`t&j7!-uWm>ZMUth z!FxxxDnFvcDB(`mD`>pw5@&cF`Svm~J&FU{D-@PsfGcifOIxw|39Wbdc28()tk1HD zh;(L5-|<(O-&z3>E;N{K(Ga63rKo*;7t|9CVRzOJ)3U2c?@koRT$Bf2JPe9$5WpDq zso)SrfaFdR|4R;uT|gg0KARkxIe2gw%K+=-Tn9KgNscwXw_5oP+$qFE%dI(9p1|V! z$$s&?(W% z3%T5uEIz^x)^wmFh-1-*MORJq4MFUnQ>Kw?L=0L%L)%DLk^E+Dgy{6S&Ry~x_LIh) zOh6R>E^!os0As90&J=F0DX?H|M@X&!09WflL@8`Db__UQ#vMC=kMV<$0p5HC6>vx2 zTr}iySTaFZ3AjL2!sa4jCvaTNlJ0FeAg zfCK>UoK}0E!q$FRPD*o=Ew){OV4nyPQV8n}!Aw%?VR>5rZDB^FFo3LKzl`2*hiU+Z zL4MPk*^;f5Nd8i)fhUSYM>7h{!v~D?J??b(-Ev_1veiHv$-6i!(b@Hp5Do9{Oi29B z;5`k;6h848LBvgK28)K9kSu1n_wU>psOjS-|{&BJ6g0PIBXmr?=+AMLbo zkqA{Nm$j!-B9Fdd1GZ2i0IABPV$)Go0l0kVsUr+IJV`Ie8#>1pU4W4arYg+gfN1{> z5A5&mEiesWN!t}v7kL-61{56&QiP{&7xyX^*l9*ZKP41BS@M&(~o z*w(|~kTe9L7mvBRnaL|h`0*UGGQP2w7obsuIMo5G<1Fw3L*lWiI2yAz&wynQl$F5q z5^*HK{UcjcT_0`|460O7)OQtH1;UCyD=OMWC0v)@59f=Wqa*siYrmnGVmkaO{Lfjz~k?VzYZq*j>wX!6lMY4GC&- zEYF`GfRo&NYV*yl$t-7M#DARsDKFGyNjYN2uTFWE=}}{{e0H75K^L zv{S;^)kz@mDiEQN;bA4jEiNurO4kzbWBr5LB=A)R;U26OkB4Ie9vcU0eE`MC*t83y ze2O_dlv}rNC>suQV9%T`1!9`$o|hKIBsJX^NMG&`fV5WT+=UDAu)iqkAx=3jw({em zYfw`YNyp>jm6pk6%st+xiK(n}?!##NgMKmger+(|Zu2!HA-Bq%J%o#kzCYD7nL{$y zqaY*mbEcx>OEWNc200OZG-9H9JlFMvpzC%1&6XxzBg%yYdD58H(l#y}hTVib`;<_a%Hx=Gi); zidp?+Y>54~P{Z6^Pm4%6(TsjXxI<`-%$lofpj4)o)Na<6s!j_0C4h8agm}gWk^u$= z1`K4ZZYr8QV;Sn%dlV#^w6yo7y=cRE7>oW5gW!`~VRKD^tFT&WFyf5E>Wn{N-St36 zhQb=)9+}ugf}~RqUi{irZevrzwUdbRG+VMt2khtQ`0U?A=Ru2>Wj%Pg7z|V>s%|4o z1&H9t$jJ@)tdD`s%WzE#qm|3%)0{Rtn_-v8LGf#zwHa}jN*&uw_@VQ~04QHxxIJ>| z?Z6jp9wF<%C67fj=N!CYZ%CU@Lg}jt{0u9Lq zH}oUL!L3Hl_f&1m#+yLp0TbIMww9AMxXya}0KZrssJuhb7M;%U9zaITqWe|6&w2F@ zNZWo|d=A+9d9!L?wwD%FLeO0@o#Lu!&{RMiTu)gr89XgQV_UR_lluSd&oFnJox;HW z`af&WEU%xD6n_)CUJTBGUf?1&V#48-@(nZ_lyJE5Yd-;mCnhNB!bM?WD&S{+vL5uP zx>&Ii5EhzJ1^mnyIP|$BT+b@G+0cS;0`#VqMKLFr+E#8WCxoM~NU(#DL zP9{j6k`19P4!9Zs&N}dj?W0E8`}wg#T3R{xkAOl-e|`-D<_5Mi#OUq^Or@yHJ+4a+ z7I1Wxk3U)_pn;=L`PMRIm)#k6D<_WbYH+1$9d%A^fAlTshco0*;2c8$z%PB9UQjt9 z(+3TZgl0FHjs1ADhbx3!a2RTwbIT2PtBkM+tBeZp5QzZo9T3MRI^e#35tAjFehI4h z0;s}(G>~7Ud^%tN5m{>jC>AQtnY=-Kc^`;LNyDe9l2zHzoJyK@;ZF-uJpU`tz}vJ_ z;KS6#8}k>`rjt7+A1?h-F5Z|(wZ9nsjwNW8Vc9{#i-gt}Zf3nQHyp?)=yQr6fX$*Z zqArTq5P+(yIx^EsXd8kg3adFVS!?)JTl`f3{{3;h^0m0ey1fM$ffe@6==F<0H4Pu1 z=-au8Z18l7ctC^Ia!aBqTZMefX-wEs;@YLtx~|8F)1Q59H}3(3pkH> zp;0)vHL1WO_m064XfeT=iVF8=Lqo1!19xXO>|>LBEr{eg#`X>wl(zqA!m%n@rU18< zETaZ$h*Z1a0j59*1As`$%_%Wu+O=Z&_th_1Ef&zMWQ6MAc@tM}f~uaH2vK((FfT_pQELJe}yD(K@t z(ZDSK^%c1B5gM5TXl#zFGq%5l(Z;s`PZV$&sRV%^62%U}2}1=ws>oy&(9e;ZNq0*` zp>+`e=2qi&98h!jW+05XXh?)o@iy>~f`km2Mq60YpE%(G*E2Zk(x9e=CcEw#*>%oZHEZu~xC#mz zJbvn%1s&{g{L+crF??x&+8|mK!>eF9^SvcI=X10)7-l!YTH?sGYJkWeP_@QrV}L%m zx+bw29pJ-A>19_8$YQh`dxs&1uK_?~!`^LNoJE4H?`k-sJ?wymu}WWfV;M5rUYX*jA0Cie$m>vG7AUj}{z zX$5qVF!!+MIS6q9RviE-3m_2dEp$*U4A99dFR5NN38E{~ z(FK5&8VV1fzb_A@b7gRx>M3=lcimp5L_#L$J32uGVbo2*s)7fU#~L`(l|SC1H6R^= z;t2{W)}ZXeabE8Z1hQaw>kEGgRRz!m3df}fH;S)(lx>F@eG247(m{m z%GHyp-|kQ(9ELolq4)KezHGxHk$gVX{`XEUmQtnEw(3kdoc}U2vQf$p_}cd^q0k69 zuUD@3D?^B2`9liNGc)V0Om$*A5VI44XflA(7xpicx)5lhLK@Q1LM>|ISGphM-SH6iQA~D#69Mpnzd@BX8<>qKiUX29Qr@B~ z-ESFe2*7c`|Dc-oJJ*;LRnt%m1JyKajcUvvVNhA2f0SJkp46)jeh*983xGWp+>g+N zgY1|2df0u_JAhw@!QlRkVRYQVWa9UQ4Y&iU;Kk+^e=RMQLVZgpBX#i=JKb#i6H{xl zkgC6z#lQ!Djl~w7F8y`XEnqJe_FNaMbfP3kqkZE$=_16balPC4xX7fqP z(o2!akcofxOX-5Eb7WpiGRh$x~FYyuftMs_Z3W2LxY1j%BR5Rf)?5WmohuV9Iwi9-OTJb z%gl@qs}~ouwzekxfP<|&_4p`p`W5r7x9$hedS`HK#%N0`^WNSw>a1yOBsDcPMWs+= zvn-sie?5T37o+pkX@vTin)k1a*FB%Sj%{KvP~wNh@MSghHVBpsO>W@Ilkaz*V85OD zz^P_Sbg5=9PRn_wm3{@%mw7k_kfcH@{g+`X4eE@HoqoTyg$%8M*G|Ca;zJ? zc{oX;Bu;68higwvpD={f^uG=JHIV`z^Y5Fzq&XoBc@32Z!+q0Ukm94#E99aNE0tL2oBtf_ z7TFNpWIaNfe|zV*@HuNx)}&YrBy~*AgEKkwRWQ7~yofw|>z%MWko`sUg+mlg@Nsg3=Z&D z5yBGIcy3Um9w^=2v035yUfZ(dFx+8z@GO|>)Q`EJ#Uwl=IY1#*G2_Z7svIe~_FLBT zHFkB28YXlaB92+F0m1+pbXn0T6$@9lSOZi|el8s(~Z6o?`WwMc6vjx2RvGNgr zFM&7~Tnu!v+#nf@h2$rIifHt#)@yt5-v3!1aTFQ{hopwce=k~zWngwT5Ux#h$15;* z8=r`2jx10${C8Np#Homr(kzQcos_@4|8^P}lYc^9Lz}J5Ht<6$<@100LJ{`G@_)bG zd7|#rksB5k9FUFBI8(7<;J@~U{I};%!nJv@8$4fB;MO2$GDwV}8AAn)G%^Sh`(1Z? zmA&I}FCBeNBwx^Zuzl5g&`*i`-+mkSVW`fdQkTw`C&XY~2~IKJpapmmi>2-@2yv-<#zSsa%50EyrQqq1TIIYUM2mpbr8y% zk1H9IF+7>p_YeGrG7LcOd>ju4l%Ws^YxdpB>@YyKJK*>Ota|tY&6&DVU?tJ+9&bw` zG-&^k`W62E-uCpBQHg@E4QhNORHBcQ=7~Fzi4Ztnz-B8W^M)(XU?D&?{ocQVFfq6R zlR{&gz{+aY!fF$w+4>--hF=r}ZE|qcXX#ZCv}NkjLt*GqnMqVl9vps;4purWF!y1Y zNF}4Avz*_n7r>^LrwJy)aNam1f3S{TEt>LE9|JWEl;!`xqbw-sg|Gk>(m`T70JVt- zK4TK(Tm>(L_Tj^aR1yw(&#p|PtP8qb`AFfPkB5PhHw%m>vhHxZd+yD;^zN8IncBr)v&0F|#hoGGJy?tKv zwdnv0A#R>%`bqhNli>xrfzwSu34U;PzOtbOTs=f>)zBtV0xJ@bFp!huLl(gQpsP(^ z6B0NaOmlOy2;Ik2rN}&0PmD?aRv9na=*oJrkJUZy5-M9aqxg zy=2zit_&QU$0O8GbeRVN`dLU?=$+=I3knq>M}TgsGWhqVhQ4BCWW*@Nj6HnJ=N#>U z;S@uuJ@|4^lN_a?mvA^IqB{)z09sp(x|+Pk#Vh}=qK|JyZP}ni^Bmr!;g?W041mQk z7&zY3xmFh!`Lq>p{r`IQD{E7F=4qm>zE`BnD!c<@ZeF*YOJXUltrJkZe>C%~Jvg&K z(BQ@5TcFOB5F8x*P6Ce&V@IourE;M+&kt)Q79O5`gH82R7>zJ#Vsvbu6yX&&d~ff0 z_8?2n8NnD4gGJ*!qL7 z^X;3S`!Kx7S5T}5_1L9X&2)UCqSUZrC?>tUz2l*O89+@mdJU?Zh-^NQ)hT#xswOE(KhcFZuTkP4umWj4Y@i5Ezn3rZFn%?Qd_sw}oCMNtw;B&L zSXtIr2i9Ois)7d*)KLY$zQ9Fp^{Dxv@2sng1=VfI9oFAGpOVe(=HuJ=9(&t71}RWmZNGaw33Fv6SoqI<%gg(jkFVRC|rRK zLXr>kPU!?N12hxy#kq!~LH(me; zYG6wj*olK~0G1X-pcK@PU;&~Vp7K0-fEgq_`;o0xn;B#|%9w-fgiV&FMI|51EQH}V z59^SERzIZe!#=S}E_xTe1k`pY^v|h7jEBrz~)q4h5Axgw`m;@E}&!9Dxncl>k13L>Sbrfg9wo z)Xv)38IVIk&Ke6&R_HODj048ne=(Ajk&Q4L-k+s^5&BVhh?qHWWsyn84x^(jO*5AV zPox3$+*E=4L@N zu%J~Kf-4i_23E-N;B5?+=f^>nOCx|+r}$XZ;JR9asnNfv zbW48lgNF(?F6QyS%Z_b0_EsIZ^MkXL`VPf z^?W|#r~QCf->z$?jBeSzj@?g-UmZP(SFSObbn?HyD15Gbec%}>V#ml&y(LO| zZOlzwN+G*w_^SCZ*=c87W(2^Z8wyZ|B`(SISPmke$VT2)X8mqEtY`jv#ps0S=W zZ7XAw+=S;F;Y~C8Qq&8{T^S#>-iug#zLQM6EVL2D-%szss2BEz-RjI8#IiAIiot>Q zu_yRNGO@4_!e_IAZqf^?Q}Q%n3)vU0U281OJ|W!aN)_%Mf8cBB2xad5Vl0{R;2=n- zp?QI{6_B2xYzwOHBIPd|8Pz{B>xt+~fB+Z2SV+4dB!04Iy zGZ^HTf-NeZH0;kc(|`G4lCZcnzgWa&cN7?c{Nfwj1;}p* zSp>lkgk2;^Eok8bb;q&Mpzm@a+}bT`wqr!{+=`@mmRPC(?d1p(1pHfTHx~GL3KUHqZb+l6DU(-Q9r5ACn2IE&lTm>@~-_*<#A z3;WN3B~w_ynjhSjS&=FOOx$c%l)fL~hq6wL;_j|rehM}Hru{N~)88d~36A|wsUygb zpc6JDe+E}y+U94;k!wXE0y<9x%;0kN!;nP)Y7rmZ^WbbUJHvZPtW}(HawHof4L+Dz-lv)9s4zwBV)a80#Fn2@rX z^AcNv_oVV^FTOH-@`irj^B`*M0>fJ|p@F7mG<&zUmq+hqi?P7#*6;YZCYQxNAH#;u zp@A91wqqM_2Q7Y1AF`uNzj7UmB**QmHsf;;)0Er?(>$#hq#?gHtLusJi{r5qvH^aF22(JUPhjw`)|w9J{jTu%qh(wS^x61GvTMd zb{T5AoHUDx6~uXxy~JBtkV-1n8k zOS|c&UF43^(Jkx|+~b3u2DuK>BR@0gkK8l3aU!EtPnhwRg~;l&jlXkua7p+Z3R5Kd z<#+UGyu8TWvVCDj)^ICFNQt6@R=uYzTaD|qT)@>%i8qf8Wvi!pocJ(C8~%ppGctbq zi^cgtc}E+{cOl2#~;1^y*ool7DzYb?4f>d?4&a?2WlT{ryby1IIUjG@{ z`p9iOu8Ge;)G~){Nr&4H#uTeon!OC!Cub)1I%O7=X{UXvH7d208lPMiO;2y8Hp=g3 zbv4${yuD>n?HRF!86{yc(u=O-I`ZVW)}=nZ+4=oOn|^G<$#EUY1#2&v^|Mu|gbeywrwMq|7n*&nBVQyokGjChkLVV{^|oQ^Whlf+L#@0uobhDc;|F zen>vpWCQ!fYKFRG0mUMo-yf!6ku(`HD*{Rs^%>5d@qE_RLCq*bF7pFGd`+HoL zQdi?#HMk9bQMhow=cO$(J`FRb+uZFRt$BK)*YvOJx5rI(OD-E^Y&e&>fFAxzH^Wo< zFC@~w&h%RLYr(2lbRH$XZc(>I&X^4To_u6`cVhg__MA3*KbgQ@Y}x2!`*ho8-*H{v zo^_+aMFEuGKVHS0*)lG(^TKCd3T^x{^;q7tsaCun&sGUOMx_{8b~X3XXYbBD2yeWQ z7?UFMaI;{rg;8f}^Bmy?a~5Wmh2bW@Cnn6$iB5o0PobcTR{({W3@c}8vL1w!M0zq9 zxNu8{DrKPqR-|O-7Pwiz%z6 zzlCh_=4p{QQx;jMA7;g;CvoCHvgNkSBjT0!eh3TJiY=5|ARM;BC~NGxrtx<#{oM{` zEE$uVKM(m-F1qG4utSjRO+kT4+UHb*C*twT4LYLo*nZV~SjV-qdoh>P23}G1pvO){ z%Om+XkIG!vNPFHjepQjitRY~v->2FR8*}Ln(-?Jk8`($oi_%U$bW>T|!lX?RTJg!J za`;wT&$_6G>q|c_9A30%)y>I27ryM#$x#m2n4&w(7rIHGwNkTHgHv~T8s8fIR)MkB zQ%#8hwqMP%f;Wn^9kbf0AZMFfVA~h-^yQWIeqp^`jC*^vf-c^3)?Re%b=pS5#a-+7M6@JQ*|kuGH2irNCXZQzFA`Fup0cZ9kU;Uz=iZNpoP9U3dKvoQNOK zs+u<`rDobjhBUV8IzQ}=TzTN?+q`N+KzGlm9q#sI2qU_L>@ri5HF1 z23;fcl-IvMrc8VZ&;Ds?8+xg0Y)FhS-mCYjt1BoeT{-O{k`oG}849Q_=YOgT$$vSg zVi*L*`K3#jGMcY#pN^6KDeF}(!lft7rgJdXXA7Jffskey<%`Mb=UHp-a#?qnTbUtN z)t=*mRb6MQySp<=P4n}ggxDzP7c7^2NrV*itZmbCoBP{wG^|_i7VIt^pZsObCBIho z9ZIa6Nrq2>Rytp%PFJWF(WRf_Hj2mV%a}Z;TV*No==AgBLZ{jbZ*rc`e6`jlOh7{a#1auL5c3Xt!Te8r4>gmqI88 zOF1w5_d3q7h?-PS_5Pgtc|%?N_u=rSU(X%Q*HSXf*9K0T49feLONX4ca*y7lXL+iX z#q-ko$D7}~#QduJrrD)wQ-Oh#%eAx(v!!Y3r_&cc8N?|uprNjum_naV$PS_41x~yN z0Q37=qEdZvlFgr)eRP$-SZ>@$0b0}=`DA->wx#RysJxWm#Y?+sEAI`g?Uy^IrGDf5 zLCZYbPu8eNTC+{F!)^COhJ3I5wVBg=>edG*JI+r-eiL4Tg)%(GnIG1%?q)eMeSgIt zf4`E3z57njXxkQC`LuJquJ`<=RjiEpM&6tchpGco4)!HG#mcAM&(W+tnq#rA!C{cc zxOkWB5Aklp7@l00Fs5L(n!!^xiOJO;1JAE_v{QY!yZVyRUh+QN9yqX_T``c_P0MO* zTN2aWXw%>BN{zWP|na^^>(a)7f|a?;t)A(k0U zA1Ld}Ut~HoHqxpD#IuIJ$Znanv;cG8;ld}lS}73LJ9uke(Cvap;5{=nbm>@M6vj$G zj-_Sl@yrR$-S)#{HnMpYRBf`N9-#kc&*gGN_p`ZRkESLOggtPda%nF*zZ|%f3JBX< zLAS(E*&s6;D{WTpTV_%fjW-=Kgt07T=bit zVmCR_n*asH-p@7YpPm7C*#`a1mtD6Rg0%9Y%C+qP7%9aCX1)$NY2PXr&mvP`YrCZ4 z=M@vFJ7MWH#m6#S8aqn$2ag}i__HnV&tm=F_a}MZFn;-Xap|5t`v+`}_YHizJpS=( z?7(nWqVVh1ePv(iDo;;z8B!FIekN_}E-^o_G3;31du)ip4spM%k>es4!Xih`etLi6 z(e26|m9@8o10q+1rFfh*hzXH4(9>9-hX%sEGW^baZu+*`JSFF;TZdo7rtzCSkocGq z8s(Kj<1XjBKQHrr-WW%*(}<`z$I_$AqQw^b8STjO$=ESi#WT|3Db?S4zR|WQ`JA%d zf_3Eemb}Y~-rb#;wl-zD0%4Zp=_?y%LH1y*-PqTXZWEHAD||Pk@XaQ@KvDMht&a;F z-|9~mS#j2-2QBS=_s%CoQ@NooMx8$Do(6N`CDGqE5=;{u6U$yi3G;A$4V+pbeLb>z zXy3wDpXTOhqR&xjV&(4Xi67>q^N$yb2Xk}cYr;+cTF^)R@CS2{MsPX+mJ+>y?md$= z1hsW=*h+>5WEz8U=S~7bfEl*B86yc*5R!j)|b~^Q!f+e|GsR1M}Gc4C+T((TI{?P9>`=@b0j0wdad&o}| zOoP-xSgmLe?t(Q&TuKRTh7eZ(FPGz^I-U(Mi8$e%&?(YT%5p!uH&WCJZXzv z1@D>E&#ZXgdT{DJ&QS7P+iCG3_EM+0?JkWlM#-I*eih_y*3L{l{Ka`S`<9|3(iPWN zZJU&n*u3-%`;5(B_rH;Jo61y6otQVBA4=XIUXW^-AKW!|+@85TOL=c~vzb=t6M4IRO%4<$OVfIhn?aL|CJQaIwfAE`=jh5@=-Q2m4<$HgdQae3y_SsKc z`|y=IG&Uqb>oh#RkaprQQu_%fo|*m93}RZ%rTYSR1-o zuWrynBVX>_PX0%O&-BzE!r$X9C#u2^sEOZS$SDbS_8R6o+;`2F!hDbL);7F`QD zZeYT@s}$hPr|xcV%oyNbDdlS59aGaa6_D%hGY5fpoVoDnD_?u1(MlEu=Z=tHOxEb)Gr%DY4odGb4XqoJmqL1S6Vy9S!O>0D;sCd= zN``#5t(!b{QtNqfkIk2$Fa93lB9;B`(-(1g=`ol|pEz~uVQ#)eOx0IUr-cD&CpfIn z$ylU5ax8a3CS|7J-MFOknwEcea#XyZTqDiTR(tng#)jaLjkF4V$Hq3*mly7BcKH1# zi|eLr_GsT*{T!nP?UCvl?gni|9j$cX?z|JGSc{PV@t$3h>G92+PkrkKY?hM4xFG+e zS^c=kx^a=QIB@r^9-i#;=snukYk2SM>6RZY9lZjAf&-~ZDC33#l5 zcHnwU`;w)lF~7P2xf6NG(D-=1tyx%9^`M>g)vLnP<8vc^ZS&3_rEf~}Y|3;G9AE}( z)vGQbNPX+V-8{w43D8Fq9iW# zxO$(30xgs;;84Q9e02KD0!S4X^10os`G1hs(3HM>xpZDzLxzX6TvB(6S@PF9Bu1T9 zRru~}(UTaR)qbJtm&hTv1(f4kOe>6e>mkB?6J2qcw7jvzQB=6S_uea?CkE4s`MmURybZ}T}TOH@PZn*flN0fIM43G-l zBpI_WpJC7D|90Yf*U0e5aDhy7WR^qe$5#s%>C@iGq~Go+lW!y7Jl!Q_6SLpA>A23b zfUF-H=||B#xSoGbNUKHGwu4BO~_-XX%TRXi0*-Q2%9dvv4AIvlzFDM9ejiymZ zV9$D5K%>5(z>$~fb>CP{oAl?x_kWygykGs23;)i~_9I(!w?omYrIwz0h09G=$Mo`E z8YutbB`9RLtIDd#LaG*Xwr$WFkv-flB|J9y9MHWg;hhjlA=sXfdB7oKF`oh;iOY59E+4V2?~Qf{@6C$ zKe8}uk9R^~kSy3EiMEcq`N^_P$G5yozixH%A>Y1)ifflP^_~?rPM6Z4XNRxp$muL^ z&4Qlub=4;Oe=T?xVEa-;@S$g4l!@{^qs>tvZfz^CcdyILJ0Gk3;^nQSTqPyDe!T0G z-BfV+@hQ$hxeMzGqr-i6sxJK+AV1yehXDUVTGWnVdjTzL`L7l>6LHGPqh@$Ijf6SL@@orQ{~flMX+#(U)@|s8r+f3|h2GSlwXEeZ;&8 z@*X0AhvW-<3UMMN0q2~fsM*qKq-P)DKtPF73Ud1}BCi>qG;X3E&&^^79g)yiz10eT z`of+`&kl5+9s0B~8tSx9ut-*_4qnf~!NJ4_K9t$E-grgqi@>h-)6ka@Q$EaeiAEB@ zx>f%Ao96OYr&?uHoth5`he5C>fQY^js<0(mi&uc2#s8)6%EYefVJ^L_6Z^*^Hth3D z>MJlfJX3S&AqFk4@11T<-*EPY$^PGVgF4dEg;$w0@*T8VIQ3hL_8TjjJh%(1st>Eq zDM&hq+aB5?K4pKVhr%aFb6pmd$UFs<7$T`7YuSNZn0>mfR!6G_J#C7ovD}}BeB`mQw4^8q${@Rf@1SC1 zBaA#}aq-L@^tYuyuim^IHbujPJ#WM8$CQ-)6@#Ba4`wTxs}_AH>`kVr&@XOTiwBv8B?lU_#;zFN zw<6nlw6o7CS~p#PpFgwuD6e3=HH&=SDI@lm-@RK+VnVjn2JPHH@zOg}cOi4iWYndk zzcA}yZchJDaAuTzVV`e*cw2UEyVNnQ7nT!;Gn;I(L|^qiJ#p`+c_rY`_p$-db;0pL z^JkQXGlt-Df4;8el67GX-*QCN#Dt5Mmi7QNAWL}Vxabxv+3NH*XeArldd{=6e8627 z9133ad~eiX2t56m0c7cO4z*>4vm=@Kzb}6e@o=kvPz6D@@3m{sAVl@g(yEp?0Sb!n z1)di#daVmtVP>#g^~qI|A6-WkEf{wD_^MjBOjKOd+ovANd8cV{#qf6a_2mM}9%069 z(tb*DrSJK(T%9WM)j_FAd!SV&|~c_*AMu=pQeKqetUgTv`sxk4O@J z?7KfYm4Ho9zv+i_5=z+jFIQ_S>ny4l8%igeUYO30D|>L^ff;yKvEklK?8i+W7RKpb zGi!I_JAA?*UBw|F#yDGa;B|&yH=6as?j->~;r~E*HB5j1TGtdQKLGEo<&^oI=m8Te zt^7l+1*IZ(#!{}U8FWmej1LtCR7otl5||{-HS*ow$bFJx1}Yq6wDhnQ5I^~D$rZl$ z9qLMEW_)Bd0R+mw9_eT71`-aP=Q8qa@TeSb{c6U{9_-H#gTq~d!g|T2+LiM|{;dvG zJF75+j3-=m`}&_(OtUM&e!CBy(r91v^1Zh)I48iKE9rTIs!Z||zsbDHDjwj1!WWMaY9kyz>%DpW8_elj&E#G7ebNxLYN+YR)HHm{ z-iO_yb&bhG;S82V(VJ)AX!rhQ--?Y?dDc_GOu^8wDMiPuHKT=WMgE zNkk6p{m85%ui=N(M91V!NeZ^(-*|~M)|e_-xf0_6mOaK#E=q%)_OhpdcgHC%h)_KE zm9G-}D`KXN6Fg}mUx23!N-7NnWkJnPCz934OqtkxL#IO+H)q^z@0IO4VcMD6@GW}B z>e-7}vwhjdaAWj+z!ALg&W@sfgkTdJB&3Nt3N7eeXmtwRHI+G=2&xI(i<3rx8tPBj0=O4{yfN~G%O zLWm$`2}CYXY0^-jOj%3_MyOF?x_`>#DY6*KLIxllp7_gg;LuS#^76#dZ!l1{UYz)0 zB6aZIp2aK*F12eFSq(p^Gh9Iss9R)DrhxBR*W8#~{^h_fDdxj%YIlJm*5?x*g{2Q# zCSoq}0!C35576#FDN`mDtqNk{eaCJMs2@euDmC`*EeFoBM4OdxcA;hJgybEwB1Wtg zi0+;uYX~5r5_8YiNh4kbM2`|YWV~-pa#dA7bkmA>4a}S(A704s*t3u3Xia9kHif4` z?n-3M^sfTlAC}!R0y*BzLB&wj-~cF#e-l#e`9G+9#(jS~{huBOiQhb0oI9r(itKm{ z?qH@v0wyuL0856BnEMjx7PwLIX!zjcyVw*~dnVFIcH4<>5|E?2L>=;UW>Vho9C@p7 zNu!i5*0w`4iA`{*eQ^om}@sr@3KIZN%Z}u?!T?tOI+EXPW|G9 z3k-R4xnq7X76I?k0%!vHcObr{qY(8CGoBhc)t5)*Ucj>AE`+(nfC2xrC&X??!G{I| zlL%gJY}j|FtESL#x6!jMUIl@>qIa__(tc)c9gnUf!GNjK7TI zOEOSU+LliMuXyL}udz#l^jx^1hJVhcNHk|3Sr|B-b_GN;ChOP)nT# z!R5<1%J(ZeX(|}{+_-ryo)Lbjb|bdK!^4-c7`bok(V5<5(-U8o6Wb>eX9`YzV1gp3 zWq|xj_*!v19Y^qnnF4VZBI-WCNENuz#3a3C>R(nv!+2~YqT9cH7+xBeA7Q9Z^J93d zdSqC8y0q_#BVG2|K3!opdd2&{&tIpYc0RuRG84sAZ5`A(D5Gml1)YD7*J(H1@bILo ztb?NuGh`7^+L6rzM7c!*m-k}!&qYH|JRwXWE>i%vEL}e$&pg9X%gMj-pHk`A2{+Uf z$YX?EZ78P-W;ICJKAWvnHU9z^1os|Pjbj7p%yJvTKs{B4=V6d9}64(we2Tc#I zTWpY&F`>)(B7--shU|4_;5J%IM6hSTJ_sfPwitPVgj;!dQbaqRo;ayg__2S%Z@bVP zcF55#vp!ir;-aqszd}ZnwEfmOp2zVh$;$?41Gd4A0zQT-2`EGy9t!M?36H6sV}*23 z=jK$MElLW>JJRMzP-|caMz_oT?`qKMh{rYy?g`9rK0Y!*afR_VY&GsF2YkGDcqc_c zS64US*Gg7adOkO;@GH;Hx-4M<@tVq>y1X|u==bC#aWsL8?xlNqC2@Z%4|Jy1-R>yg zm*w<%j(=07hseeRUyT>IJ+u@eW;FEvM3j`E6)8-lT5-Rarq|Ofk;6ez=5v2bZ8$RMng5 zSJ69taDPi{Yis+5VUN#|v2o|Q@0J(;!b0Q%Y#9J&_=ovX9QxmH%Lu#m@#4r*k}puW zf-x_zW^qp96a*FjbU3RZ3;U*`?Ij6*&N0pIuTU?}3XiJx0G{%s(#O1+*ob{w+e3Q` zR6fLd36hOrXkypQx~KfVdJWu94mpsqk4}xsGUR`z>l3wqi&Esa$W$@th+fVyl?7<; zyNub44zx874iJxHbbKUo8g?YZU2OI~PEXnsy$&egn9!-=X#V*%;D;lu6z08rPXHKE z(7+-+C_<6}C2~#pmYe54y0L>|u8e@y?{0dpX0#;A^cj^HeI;#~bgb&aEP;~eZ?^hS z3F5Ty7F+O%OcH?W{P>=cT~=8sxVMcPg+Ey}g}Mx;UR|~f0U!b^m$;1O6i*{U{*MO! zn6o^?*?=Y zWt4n5!IVPHMvHt>WS~Z%(a$uv43Zxw9kIFML3w;tDA{K(z`z#>XQWM4Tkc~LnA#w3 zpLR$4pLWu6pMnB~5Z6)3Nx2GHLfQOsBHxtxf62=tOlrJ*pu@vj<#Kp<`0xW57_mYL zrHYS=D?z8tI*j16(?zH&J@eo5V9}mVaN3dfKf+Ci3+ZQ&RUj38HZg^ng&!X9XX4wA zEKQ*R16J15^iS|n64hulgy%ZUSLcG>(4&X~U8*R~`zmJ32M-?X?D38<*^~A+D21PZbF9(khx3ZELO#VLFQesOSv^cCRQg;QhCjqWhSX_922UZ$ z|F-m*wDZUS<9)BL$*4ZD}WOL)lVh}u&LR9BeFe*+5MdIH2QRu(2(w=qs7o+k0 z;3f76BV7UjQ|_YU0h4RHD6yxJoB4Z&c-_qYz&#DCpX&RrVf#{W*jJ0NugKIv&`zi z8d^2){~-zAbpS~isx1a(?}R(^TDhvnfNuC8rvVR@SPo&zVQaE#PXZ zQx`tb___a7kbb+;9Cn=_d#JWt4IR_Nb0UxqOzsS#Ydn8mg(=dVrytRSqK3Y*xAQ9RaqP(USNe2*36Po+D-56_n#a(a)c2ISEOObA%n>qh$tI$e`m`iu1 z_SOEwT_z9)3#>IidK?4w=07EV;)HqPQ2YzEGEzLPE;Y6hY`M_n`W^L1BQ zp1x7P!p_MXVzS+Y9N?8a#jj4q{IDN>w$KQyfdF0+c0)B89M*A}fmt zw*kY6D|`x(+w4E0DBOm(M|QdbJgXy@cUxq`T8m|BRB_x%T%9YhtoaPa0`qp6M1!@A zSGvR;H|19sHrq(QWFj9}+x*jyTj>xBDvec5+QHHvbB-Hq~nsUyzLhe){uHi}ybI|1; zU+!sn8|-=US%s1?7p~2Vii$Y2lsSEnCp9+esx9|63DuvusI^?173(kJ4oLl-(`mPvpJsYMDGdy>wsOZzL$W~FJV4h|PS)({U(h|3^65=UFM>DJJP zB6cE-Mkm%sF2URD(9`XHlm^u`X_0}?iJ^U*l~rH}Op0w-N0q*M_3EgsBC|tx4V$2U zU_1G^bG|^cFP^CqC0BVv!%8xZCel)v3R6n)o3~;VMqDW%ox`)gGW+%+ie0Xd)bsbo zY++&f0jiDHu!DPES*e_)^9(;EUdZTgWsOX)_Qu$cHn008&hU}A2ycOyeyI9rP&%WX z*@CGjts*nubphZ8fPb{_nMf1r213bMR3E!fs5IP47oZ9w6EXG17Nx(iX(DUu=tF+f zR!gL63UiGK{C7bTLo;#FiH4B$(_AuV8KIg&Oj5=i;|)NBM}>vm6L%k zfdqX9a-xO~gd4S&Hb;oG>?m`be)pafg6qKQ-zRy0Yyg}ViB@{no`(op*VQ~1ckrTE z|8roF@e$mKpQwgt^#;8xmIblH2kV6u{#=d&=Le5IbI5?_?kI>Urd~!b=i;Jo=JvdU zY`a--f&xvpl;mhd9x4}DdFpx{FdxcbJ?E_IfJHj0JR};Qq7%l*3iv^90Q{0sD^a{a zrb2=nnG)cLAhnP(pS&#EdA+Paa3jQY^$Wj#);)$!^Q=tCXwB?H)#dBjlhc4R_!lMQ z1xU}|H{*x#GgMsvoF6umraElcUG8Z?6ou0uCw#Y4{5HYq#b}IDGmDCS;JRAf}b%~ zyz=zvCDWysO+sJb9#PGu9Mb_%Uc?X?uoFScNV7&{mng+w0Lvgx>FbzbMyo4~(zih@ ziQt3^iL3!ZY^M9(oP87Bn)F13luu+$ahZ08)x47Jb+=0nY{J$SZh}e)cInTs-E}LB zqN8>>zqKuF|I4ZCa3UzJjB2`2bq#nUS%A1>$Bx`eRNbuwc7DX<3E$xYHXd7zLkIeg zwK9`?zIS+A)FrM``CK)wYtLCeYkp=t{jqUPn(>G4Pr~ANNAKQCpOvAfw2oK!Cj$}8Dbt^nA*HWr1#fWt%P9L$J zuNWX?F3MK8=$8!S+qJDC#&1Zrrzd~e;WV|fBEe-igwEFzLV{JJbz17`z7<|;gG9Yg zc_$co`OlZnqpff`0#y*qz$FmWV%2U(`j}X-gZOSN8U%Pm@etNaH|XbGvt|+ezu>xB zbvFy>qc74K95`V|9GiAp=1xbd1nuoPeQFJjm~8G!4vsr5k?YvQZ#ii0)mcn!)wn?? zpr?ciX)6s)W(MB5GrQ=3bvNnmVXyG$;%_nZ2mVO5k%xBxss;F^0E=+GFpZ_Hu$rWl zuJ1kFxKHZ(70P4zn&$Ib5RXsd5HnwGGU?Udy0v>2|--53OS3Yzv zMkYJZN>ISBvda6=&~Pw2RvHe`TOa=gX98Mf6%&)F3OmaMCyz2vcH5-?I;qXf%s$R~ zn&D@Q$flZOnI!?fb)DxdO))S8!peU=zPnyp7j~2h%6#mleZHA-znwF6IqFXi`hGmF zk@WUdg61*9V}ke=Bq72aIS^+TXfKqrbb!aM78SYbdn;;v9bi&hGk*@Myx@@#hySx~ z%}cA6gFzJG`9V4=mbV3SIS8@yjcxz{_=2fU#qn6TtG92j{)-y8HIe^Y$hZ1VA&>cY zDy?wK(uv?CWI@)eSovnx993reu2-?Vz-<))*X9baIx1=K58Me%-_7_baL)P3(g+xY zlCtv44WudBcBZFpi6EE6x)qCva}+$<$^c}OE#$+fmZUS9Mi?mf^E zuReS)?cyfg)IdDO4SOeRv(IS8V@ceHwY3)OdMi&}5d+Zs#Zqhyz2Ya&ZTk0n##O{GY?uIf21@JSI(n8cAfFim5Dtl>Y+NMA=;m%!Z^ys7WHY zHe5A{{5&?JbFOuUYmSWTD<&Wz(e$yyPde7dx*M4y)ngdD(vB zP9fLfYtt|ZxIX{bME4&Z;wG06iyMd=^aE_B9sS<47zGAN?E3diuAwt32VV)#=D~TP zuEadHfb5JT-PGf~onCQk%Pkyhs2da1g9Myf#>wE3;#k;z^dtP6&^r?*=mI+w@D#HG zo$W<$r!)mM=$FlzhKJIEs0+D9oL(LIXQOg;aSb^}e!VuSfWSWpi*|(Uk-gIgBGXtz)ms1d~qMd$&P2hkU z3eKP^*)zZaO}dn1}Jco6U_m5>!!YKcH z*R5{=n|g+s#vAj%wL*Fap43b1HNC`lZq5fBBgH*wx}Q-1V!K_yGH>xNovOwYygzhT zTvkq>JJ}Sc>V5#{DnpZm)GsOd?rtnnPyI&}=*F&IN?KZ32b4Ct9qyd?Lzm{iXwEED z44qIX>g96e{nG$6aXl58iS;5ecR_!B#eZ#zls{jX!n8}x5{dcLD+SYoAnZ7a0m>uJQ95tb?uF&f{_Ve!*KO^fNRl)LDQ;rr1tuSNLp7s3#+6RhN~4reyC^u0PiA4zN0^ar;T!iwG_?>A>{=tfUe zl%GTP%WeBA7^b0pSAoWzvF@1;8`D7Z9ck*@XNdq;&rnswvGdV=^i>2ou_8UJ|@wJ#dadzwPrVCJqxg_@0-`YeVygNYWnIQ zL?_k_CV&d=+I{DdW1f<5Z6OQS9=VJOBk;GLovtvG7!ed#|TuEXZ=@lK!W;+;_t~xS>fwg2jG%duM znq$W)QY&JU5HoSHBwjq&AOU2283x@heTXiNirT%q#PF19edhf5YG}x3dzMM}J zIrqVc8Gs_m556!}u|NgeDmgrsx)D&0i2KyoA8Dw&y3a3X2%Ety+T| z60q!YX5cLc*0?~X1I4;V{c~g97USgO=Mg&WMjnHe6|s7e)k3DEL9PIS_AbZmC9@yd zS7>LT--$(P3NwmCAC5H+`yyf-M&Sv4;sU24^JvdVLX@I>Cu5bB&cb;Rzoh~KjvL_B ze8w0T@G`?PE-4D-(Yd}=0le}7$lZ6%o?TfMAl4iFt({$+Is_5j(y%gs@afd0Iv$;- zBc85qJu>&eA-G(gakUco4Sn_h*F-`qVooGi$Nd3N^r6pqT9RPD9I$d6nf zR8RSKAS#m8Mh~zC3RbfMxyo_23TY#W8A7y<($oqz53NBbY=p>I{Rc$9M=|udy~94H z-~$#pVN_T?Gd(W9fs0%HUI@LugZL~cEf6k<>q3@J5Uchh7l*w~b66;3Zvwj))rK#h8`r%+5wY*Z*|G0>WI0YpM@LnN;PUiC*#Y=4);?3t#!5=Gq~rqHMSdiz z8+c}|0K4a{t4z+)q`}uR0#b8HizVV5fVe!{mK9>#my7(t+r=j|T+LRpPqeu3!`;vG zV=d8FgWZ;3eB<@(XzPWLZ$pT~JCG1KFefE0XsjaoOJF^k^Qa!YY*OCb-*svF?}A!M z`woBfg$WLwWyG_FVg_J?*gTR&PsIR_T&A|5OhiHb8EP`rrVFn_OG4HygTPgW3E+Lp z>J=r^c=cpM3c8rt9rlACIAPGfj;oyN`^U0;Pl>X8nlG zI7B>HP-JsDTG4X>>5*1V;!`dmL(qp0cWA}R-2l;pmO}Jx2;fIyVPO5?u>`JM%HiYq z1A&0iPT!L)Ksrsp#Ee)D6DAS{GS-t@08b;OJp-d=tcMpycHZ@z(dOx&-CZ4&6DFk$|?WFTd3kXM_SIyME|{lfHLX^^9`_F+E>DW#yKA!hz? z{?B##jetpJM!<Qb2~f;Z!Je9Xp3SeS8)IRJ-?Z2tHQj<>hp6ZF7>A zR6F?H{Rh>4$nqH|nGQI@#3dRU!?!1)04|07^bbTVvg%;j^+3v$G*^?&db#S1@Ib}HE4JJ^z+knR+Gjh4 zZmJ$6>r)7r0)`Xp#K|(43(n4Q_-o7d-Fq zj6P^QA{-VEG={*rp&IKg#Ufg!?{uT7gA3iA+VxyvHXpp!@>$H&4mJ@|7CKnM#3Q@c zu=b8n_(4)78>>N_ZhynybKz9*jzwx=RbI1>NI@C3Z~L{d=_^XWT))KS98nMuMGO-Go>CHSGH@`ku z9Qat&okLZ_M^+KXJql@PbZ6(zf6;xjh^LIbO{{R{vZE1C~{6 z#1;-|8G1(=bnF~=Vk=N*TD$7cXf^EO`<99LaTV}(Ct?WkNdv%qBfvm8fIo!VdxfSf|Ord)+!~ zk;M{WK6t=F05>{h(2%&ebqv~TBFKZBfW{m4%@L9iHk_@jB z4dR~EI#tJ0#T$gS|Bw;ew;==-&%ltZSpA)k;JZAxf zo^Xp$u7%G-(@>CETHPnJcnzdk01Px!wC_zidk-E}(&SznNIHR|4gJYv^oUThSt0bOs~W*)+DO_#q+_ z)6>5{{YsQ+C}g1&QZ%-i6nFfDQsF4VnL3o!KpDrbF8_`^M=}NGIc8OX{KONU@(GMf zk`%C;daTDg6JG7aH6D^WGG_qsm?2Q|X3{aC7R)GYCKq4v!2dsr0T%JX0$!P77m0sB ze*e*OJl1Nu+JRO>dq-y#l~?u}0c!NgM@-~UPAvU{?HeSVp;&HxW!|;=@fj(m6*L-2 zH79(&%})E`{c|UeOd=%bNC`HA49y>8nfYF9)21cZCrCqylD}kHv! z000vMOCs-Q#R5UNuROm@x6lD>9O60&=1Evb!D3b+52W;Xu=~hT3g~18f6Ge5sB(Z&4rmjJ)fSZPjeRqqNPPwd_zaO}7=T;?v<_qd zNGMI1BM@5U8X5U|jX2{@#;tQc#3~EGe9fLR`u}YnYwmoi650{Xq((XL{o=2=im=&Y zyk84>*x_f_o?JKhq(LTsJ8p_g+L4}9()P8n6Q&_x1n5pqVFICj-%N3qpb8GN`^?mN zVtWp5n&NzbB_AAr-1gob8<5EUTde&w!=u7D1E-4E$&y~oa~fVhL^V&A#v-XLM9D*T z;SjGkV4obL0Ke4$tpOi>aHdxa$C9ut$jlda)m(hTp}WSNsXg<>n&orE51Paq*LAQ@ z&Yyeds@RJ^uPMr>>0YO0U;-d2o1vWe5W@l<;Z>2O^4DB9YlO3$t3a@;s4?by!bwiH58>FZCReXgcN{{9ztre z*D{!Q4>ZR!E_*OHnOlHmmU&0AF~@}E;%j5odbRWFdfJr#Few4oTvw;#AgdsJ2$F3t z7}&i-P@#{eDTp{i&dy2&L?i(Os#6?7B;v}A=d*+1OlMz#eK9R{)>`NgcFCaW-aHpq1bORsnGw%^;$*z%O2}>CE>UHrUv_$y`Ox)QOtzCOrD)BI zf#EXD4mYSN*K4JXP;sTI6&Q)26glzqrmwFDDn{(#%&Wx*!HC7KLk+}I<~3`sq4g^1 zCOt4ly~MAAkd1)R=h(QCqR84h_2saonq#3moUnB7e4Ec1q?uj2cI^#Oc6yXtcT7YX zm5%u+wRo6x=<1S;969(Yk_g7||r89xf{Ahzh{dt$EGLgQF(#xEN40D4Jwf z3@HfEl2-P(J(4Nxe_5U=FpqCII!C~x5JHW~@t@s$LBZoeEmVmzgoq%JfCRww1ph<@ zYzaa)1A;!+K=T(%_$2tx<9R_REME9OL&nt3vD|}KO;XrutV%JBMTJLdQ>T`v$8=X2 z%JAGOMy32@>X$KDE*U9lR|ep`e?{c>(+#rnS13tFc8|BuVsKWjPFn;cv4gWdR{kbl zCt3mJqIPTq#>ydO&>XeF8YRFw9&GM;sV>V=A6 z6;$_1(fYwfm548QQNy@Ys~$#MO~fI<>1w*E2z@=no}(8~}U5i$>v3{58Q5A^VN4@Dbo+)x!pukR{Y^ znJa-yq&E+2l#}cD{v=z+)CNw~pU@l{t`d}fIBKG%l8kZLcOG#6lm2Z{OD15q(3W#=cNbu9EWn3WvM4$1uMb}ak5Dkbo} zU2XOayRDp7CL|IHi(j$;@6HUtu^ zXs03iF*3alvaGMP1_(kqp}eV303(4O2g^CrgP^@6rC}#284;1j8;_?iA3_)8Vl<(^ z&YzR_^ma*sh?{KkF6xNo_}x9IydSY3#68Cy&b;&RjS(KpLJD%s@m9&`YbK)2?pdp# z)SCy@58Ybd5{*_n0pXrx8Hd4kX^@V}Yik!HfvQ0&cmTH|lhHFO@aI_hGoaSTTb@jm-Q5qC z5kVFyrT@pIO%7F(r-`TNMw2m)OZ5nAV~g_yWh3Ce$5y^>d1LfZFN;#=pC_q->A8=1DOO;1l7o9?yY7D8&s9*wOA|AGUo#a82LEcCUjls zYtF93fykoH!^q-s&6r)%Beb=umZ$9Xq|zIg`CpyQf&H5sw(lYVEkkGJ?l(u9?j$c? z86fSiv9zM^jc1~b!i%|1oF!;4bC+Y4u0SYUV6IWvsNw!ZL`*&R6MV8NPUhp;X&}51* zqUTds4-|ZUQ%;ab$a_~Z4^hOh(SHLc&0P$|iDfV`48jeoS50fwm(l=2PMvO1V2Ab1{RK0&5!Bs73fic#!cj2cOC=&v!{LzkjL-$gCl9M|OJ(W<4+n;Ccg?2qY; z*m+)*OmK;06K6yX#lwL_brwP#1)y392?ZH|k@bFv1|fU)?Q3%1tnC-rdu%>xSLsz9 z5#$P+R}EBkx0IAUKv_ukgkkt_0DFJ2a*ZV1N671>u|+uM+Uc|eG>L~mt}R*9FsfQ1 z!53c3y?Txa$zru3;$!8yP%_{7CoF`_`TsseZc6{@to#j+G#;rt^h zFFS?KnTgLHb(PSFvvA;>ue}TNHu08=Vcw}JzJkK1mr|z9OhaW0V@7J{ouEP`c_8ni z*ocV7cuUtWu$npt4a#`EOD8?lW4;h6e6(qgR0mpA0OX_s3=-VWqn@PWc}Y{Udzbwo zDM#w-ZsxyqxHFX+Y%zbmHH2vpk&lw)!LWn*g78OZj$jbOJ8~(U|8%oWxv9_PY^(#D zO3?f-WhrEqrl$z7>d;(gO{$ol{hpz2c+_M?9071q4UxSflu~r@!=Oa)A++E=s36tD zyT%;nv08}DUe19@fG<3^CT#YaSLr3L=W6(>m8LlR^X^rm>_BHh76${-ek;0iHw%S$ z_G51`H-$9sEx1`?Y4$HW=b<>3puWJx_A3I?s4qXM1tl%G(U6tcJPoh5#7xKg*Carf zI7D9}d!0So-k$*)Q)9d%o7j9O9C{EF=C5c%0f3bd?eW+j5W4`?x08dXE!<@Pgk}EV zs&r{|l|>paOQUu%d3rMtVoax^FKcz8+HsNp2OOa%qioD;0g^d_(bwVk}~^wzAs5DpN|2)CVqrB zmL%CQx3ev0Cw>s{h-*N0aWp6j(D>X3iVsRVF`=Qd?@{B-K2fIb$p0p_K_`P2+Y56w z3D1uTx@vGoi3|l*>DaAv$b=YYKODbTUA^(FPb14SYA9a(2B&r-^z5yXM)m{3-QiLgUY`vwX}F6_oxuMtWT z_yCEl-h+WM@&69+;-JMk1I^T9l#gR{0p2qNYd|-$y>$?VV|e+UfU#yzk5bb1#AW|t z@xf<2(Sx; zXizHp?CBNKbos+0m6d198-bk_lc+2tv;ei46`%FL&%#?IhV{D+MBE)fhFJ&zy zfIwkFV*i3g7sC7epSp6$r@ek%2@U3t?@c0zRDenpHf(Rt61qNnRCLBA*ePjV4D;)@ z;4SfsoB&Fa#=p7^-`F7v6E)09z}VV~i-%}YE213$z)UD)4?p9TSw>WiG+rqP99Sk_ z=JO=S_n6ns=g(CZ;*e=3>5zUCw5YAE_jHrjpG35?9gzcSjv^LlZIH+z(a49&mP-pq zUv6IbuN;1fKo6{%kO&E(w&eWo7Jm%$S1elj2ARN65wZs83a}6kIN*wat%2*sdn1+` z1WB*B_Pk#rc+GPA^d=iX{{%fJtJN$cUxY=M?ImOKCfKGMVdPO4rY&?K+Jx1EVWW*D&8aYc#m@~ zVzH12lwJw~L8WbNk=3g4Tf*Zr1(ro^#0$Fqv5l5dqMHVXff3&Uzki#5Tz2} za)K)RYIOX0o>MYx0r4uYH@uviwT~flO*9=k(uw3vHUgU#Uoao~pe-5SnC1{Va#avA zom;Ql`@7riRHVaure1Ynxe4_5_-PJE{De}ta)YJnRvMM{HVcpM{v##s=``y3$zW=z-$K!llSDl5=dmOLhc&-B^;0a6~Bo9V}`9|;8 zjmcc=mj8XA!#-q*;^@1EdrzQQvP}-@B+!Q)`m$%ua@eXaj@eY*=A1BD@Y7j36O7k_ zGzyRrwq{h>m>9lKHFe;!Ya$UEfczsJz*!IEavp-w9NYx4L2`=~qHJ7!_KX*9KRVC3 ziPi7lzo$^JfcQ9mFtS|ry0>@Nh%l2PPoO!b;bX~0|2B!F+SLGyOP)Uh)BpgRU?=Qm zOC<@f*k2Kk77>Ae`3~6z3!*6?=B>o#B+7%VE|A(AC?v_X6LRnD|C^@UL~s0*B#$S* z_1}~VoyZS6Ey2{Bo>!ajlJ<*ti82g%e6lG2=u+cz>p`&2!_zqz>O`O8lal8Uxc(CR zx`6_&_04-q)GdVtj_Q<XHTk*&3UaIeiG+D z@lF61La!~E?rCB1^tk5-PfP*rGf{AP5abzD_6eTQcz{VJhL{D&>8(=WWJ%N?Cw=IS zBP$~V8r1?O5$X=_6S2qO3!6Z7Bx`an(x`+E3=j)-lM{f=B9AngH2(X(+JA8@JAJ1E z`w);9qA|bXI9a_m!nENqZC%!=DZSr-_)}6=It=~|WnFb|X6h@PMI?NL_-EmRye1j9 zIQIeTvAj_*7%h?rfSaDQv@y^c`U6H#nt}z32m&OUSTd#`JpYn%mw|s3AS^;vqX!RM zEB&RAgi`^X6gff_r$Eq@i;XPA?*@vDgg%6figR$h{$i)tZwR+pCLhz2AA9ILlIOL# zsqcw6Kiuyh{9tSP$~RPnes=?_jXesAoz)ih*eOz?_6C zLSCAmh$sk%h#&zqz2M^aa~=$EK0N@L}Lna^qBVjEL1=8s1nx-8Fq2K zaf4O@>99qg4%n9M!yr5A7JzONv2zf{JkqM8pSq^QQwbGU&OBhot2hmg@q`)zrX*`o z7k&>NfKYWaQ6vEo1-9B;`1%~>=;9c??C_DtRP z{cTy<^!SuOw$;q3Ez-dQ*oayI{a{Moei-PX9Tkdrsts~Hghw|e52u2@VBJX@#am^Tm&^4=#yv zg~T)fE^wVZNU9W?`PWDaA|5d4_;$go0}w3?$PN5E7ww`+Gl6;tuU>4_o>N(4&fS+* z1M@Ej4c+3dM>ZDX`xt=C09~2y3%CG9;29uOQno2%iyyvd4+zHQ`U?jw-tm&%@_14! z&?3Q2Sl}=g4Oxr6I@CzFP+bi{h?oTWl9vzaL$Jmh@F7RO6X6IE&!TbznDZpjwG`r4 zkd%^&1FS|o=Cc#RJ}jZf6KP_D)+x)b;M9B%MBwUuUJ@p;&=g!`5M2P(?RY zGS>;7n!}SM;_dJgm1xf!i_GGRKV1Ez#rV6#E_`);+VTC4;z4?Doz3zk4~$w~PD*6bj;wa0tgK$DV+jhLi)s zm!jF|LE>n(*~NCWVy1z<6mm+FmjgWsPBpNF?HqhpvcHw%1$*3TMmi+?ki?=z{yGj% z^5M~0Y~wwDrJAdPM6;o2;Vo%=@*7VB8=g860T5Y#N1dwzZR>Q10Bov@vmX};BZjZ> z8I2F1&B}ih)~07ZB219-5~T+XbV0a)`4)`r6xg1yt$>+`LXSt|?LrT;ozl#Fnv^{v zlEKMh84VR2Q9(fBMq~{L=s7CtGEN~@(b8)xGXN*=KeD@WGt~2~S5((MX8$V(GTx{p zT zS!5-$@bYfP^jLx52IG5@sOYUj(bP~HETh2QdFpKJ>PPb*T3}N`6x&w9;p03rZM7{- zMZZcs7|^*9c_>)#vEtc{FlLnkD%&t%`L;IT%fP@*B&Ly#g(P^L%)S`Tdx?n_lWz8W zx=#KzaEfMwg_qAl_ro2%f)#A>O;qx)-20ScC=?J`*?rUadN|&30cZl(hD-Z1*(<8l zQ1T|%fGregq3#)#+yN}O%uo&}$U18jf^_|(> zy-5xCB2sueMh7^gd#WTU@t*RzXY<<{7Z`PQoHG|h$q3a{|&v;>_ zDS_*VfEMvBz)(r|wI4PcvmEWNA8_?Zp*MO4^C!_GqU)(7p@#rq6(R*U9P7MypJb(> zYxNx+8+jC)Q@cSvS@yn}!93OrHWs-#$Gf{{UCnmka~KiN~q*+ za`M^@T+mBi@x5K#RsHXiW9roK9d3!LFLR(^)?Q;F3He0jOdy=yDXq%T@wfg9xSo31 z7As%I>c49$Qh=jNoKc`k-(Gloyz(Pgj$PVUGy`(x^f^WWt=FW+yYx)DZA zxF?*{6NrH8K~e3@X9)}p%z#^hZ1sdTE7R;OdC0MH0)~%`6!Jz8$q>Gw>7dnH^^KUf zN7|ueTBV_VK(CNO>8p#>EZFZn`X17zawy&>=6>0e)`&1{=vOB2HuPd^2Jfw};Pg)K zQD~`wKM@CQIm9z$l!UN666TR`O#4;H#D3K$Wd#= ze?mkNmW>9z zgrC}msC5+lRbttsJgPW@_Lks|P!b)6of{Vm%|EDcxV11d^PALCC{tq}&mHJFQ{qAT z?=?Vv&Okv8@QAdNWbFa&aJFkS@%+b1c$1C5MDB$=x>qQ}37WrtK%GXGy`tM97GmN1 zWkQ)unB8Z zvikBX)iYD%#HNHP7K(i4Uy-16xPM7{?$A#zj&s{rlKz-X;7D|g4gF*T@XXK}k zgFy@~JHjg)Ttf0Y2Sx3u^rVe^Vef@9nxW_v)5L7;!{}X1LdEp`7k$)k%$VQSzW!G> zdP~P|ZhDbx+20bWuONJ1;tM;|$zu=KDw!_fWz+%o0Y%JevU{mx)c+!kDZ~||%n2Bn zJuYl-*@>l-h5W%g%PGg$m#NDo+~{8Fo`Q4__Q-4uND+F(_18ENu>L+PchBRXHGePj z4f01~)r(;#j9c?p94^Tt9cZI2vy$_V$FDC_hpXo~$`VPeCjQAZoqU!x%gLOR>G}WU z+PFIdoaIcW(b|;LsXpj^Vqm8r(-E3?eY>3yegjE5fq4!hh}-B!m_%I~jTXTw#$nPJ zM?xmLQ)MFKK9jK=5QaN*rq^5b54$w}OKdc=F)cgx7i!YjAjo!BB(X7#v zm@T^o|HvG1X@z>iQvQQ~<*UCs2>l2=w4_8y-QA|2I|fUx4K~M;fee?4xHw*wl@VH= z40kAm%|ko?i?#tYdyR9SWULN%7crvJS0GJu>bLm+>hvJ8UBH$^bV^iqL`!h+po(=A z`9GK%$Os3~222%~ABeecSByAR3SNYomNxM4o;PbbVfp>*fRId+=J>+K)cSW?;tG`$ z&b`>Q9Db86v$x}J0R$wK13nQ~ON@cuD9|XpQz7YpvlmcAO+)kkSEz2}I|66d!gXH> zI}H(XjeI9@w-8ObK-huks|zFF@o4OZO9D=2k}%|y|F451b*tAtp1RB3K1c@%0B+v?%5|c#ThYve^w+p<&cu00egDM^_#`<(pSEB}LdsbU7(*tPLLlw78rkyf}#aRn+&{0OWMq*cdb zq>PA9A_PI7{#Z3(-6PfHriB+>IobYGGof+`q)P$aup&X#G(|yOFP!o->mI>At2Fwdv8b!N03a%U1j;hk`R1;V z%}2hDq-#m1o@T6_4b$E0Xce*sJTU`@-4eWgi_@(os!S!}^B=QwnvegJ7h#z6<7)(h z{ZXjjys=RjRZHXF#34ydP{7i;hj3{b@9}$XRXzP*W&b}1jUFU>U(1XHd$H!)d55>F zGD{xVb;rlW5tT9gOysrGD)w;V2$thMO>ZLA8q2cgUbT^lr+DUz06wpUdntcQsK(*U zS1R^pjc<9~h~waYd&L%!3TjyIdB zmq8@^{5-jp#{KnCjN)_S|G>>;)tHOQ(1+iEam2L-wG~=2r}Kc*NDd3q)eN8!RSXc+g2^G#KJn zfEh?<0ieGZVaBI)xjT^B1xrEaoY2FJY0TJnh>Of){x>K3HE;fiz5vc z(%0$V7o699*c1V^@EtZ)T-vI9l5~ZE{{Z6B;_^uBEtw&)e&^ydKg=np+kn1&Pyx+C zE6dACc|>GIQqwfTTdu= zB4{VdZ?w&Zs93^?VZ@#>)0SWOq)afZ8MnOn3N{5muvKJW9D{E3tq0yjQ2SRml7h(s z3O#I;AQ zB(8;#?qrKgdxPn1iuv}UnecJ>MfONfqvPlkxbYq0QZp~XmTE6#9ohyn@~~! zz}vbSx`rAe1mtJ)YyeE;NHm+A{2p|RPbv-mf-QAZ|7N%i*CxDqgNHar2dkqdD9>T#d@ALTyP)ha<;oTrT1&*d0 zxaB4@d+oQ~_+57|g$8DFHZ*cUDskGreG^0Xeh)>ah)j=EDl?a?J6XQ+xhMa&Yvl(w zZAI)6aVcY(B9iw2JZBWsvSfVUZkE_2^UvkCi8ldw>2>tlpD^tZZyiWr-x*I3sA#j| z^7j7=4+}9GV%-Zw0;EkwQ}GXl==Jen_F!my3H=7qFpk_y!32H+m#Eox3tuHJ*Bug1 zLLur?$Z6LQnLn`I2WcZD+#kSVn)~!OdXiZ}01`qPF)wzJxoh<%W)`$Qgy979wd~le zIQ8KJMc3p1YW*}xONSr%%h8h?D0qBY8CPip+Bp2bcuKCK&D>_;iLXZJ+0*HIzCioS z;N&DxStNo7EYUoC;FC?)pz(jh~JAO z(2)kyAqvlC50QW%2OAzs~OwA<<(gTe29p<@H1@`Hk%4!QAf z_r#Gmjf74#p#+UxWr^m}q}L^~q@<+#M})p*w3wo2X5_60(K@?VI()wiLXc+lhOB=* zAv0yySr1F5XkU=g@OXzxYPatG=arV)O7r7~$K~{2NbnqH6Z@5QC^2surhdH3OB*tX z1;}%8VM{yV=kW}X5ZXf)?>2+_seriQF2_W3nW_OF7!TN+UtL-D57mGtF(MP(0la{= zf-u}X)Bt~4&k92QLKeV$`ngXdC8$3BS(G}L04!4#`|n)6mFcIG5{Sc-Q-kaI1pm_V zs(>2(R{=~Xl3fwRGf~EYgd4SIj$fKJOPHj;|8(T&`Y~$0x%! zfc^@53^Gz^)HESY16bz=B6C}|E*8+}c4Y>?yi{^KlRaO3U7-iZ$<57D1-5}TI~eIy zRwcKZU6$}6m)C8TOlmuRamUogeJ(hWIV%mwXNl!^_D)WEPkzJAW=MIf z&#m<`^V5%-1oJjJ?${h0OHMER94A1_anI~HY{!P2$!8{Y+ebcf5%mH18z00tZG#8o zB#{x*Bq36U(3=Fd!Mu2~^qO^wDCY7t6xT^k69{d52s8?(3u0`8oEE|G0o=j*4T9cL za9Y*ep9j!?Z^)yYc16)wfm@lZ;6LxBpWY?=Fp$+APM5k1FJ;*E%Ms!QRAnhjGd>dubu@U0 zGA6ZR8N7^S$t)1Q>ihGw0hWEL^IaR|SX;X!J7itn9xO8I?XvX!rXk&RzGjmUL)jbI zI3ZM--3y|=MF9t0n0_$bSl=(*WwK%O`)45)KP05u>dZR!whdRcP1dyEs1hX551!&r zjt^D1dN-MpD|G6f%)KXS?%tcCB2-Ag6*QH^F9cYiBjzZCFod*6ZUaemLz-zhk1cUc zlCuhgASrt!5KQtm4T(>H;QAgUUj@VM571@0`Kn2j?&gj|p3KYi+j>heRZm7nDX zcNIVz7M6FJ79TrR*J-hb_Cb-uj>-m|@*w}_0al@q&Ffr~R&9AYS33M6QBO!-HnH~9 z3FAItFUGda{5s~!`%RM<_UfITG5@Ol(?Rl_v0151`sN(Ue6Prg#7`?1#|^ruWT>wUJi;g*o0sy4%edCZio zu0y$2UY!~f^KFYg#y(%aE!?m0`7}6IBIhjHqf)u~dn;F9SCoLMOUUU{+vTsAWp8v{ z9ZP$&$G>b(v*W_k!#@T(9SjAWc7%6csNtC2{KCdpv2rhP&Qad` za%bq>?6dY8!d-QT9T%7aF6Ryi@=g4}skB*U))_D~Rx zsJzZ-UdeJ_#scvWt3|T;+4NIwR}1ohqT@gHqPa={qKrk z>9CI$@esLEY1^Pucv1fet#U5YoqRU6CyCR`mc85Q_`Xk2G1Xi)M{e2A?y933?FW-o zKA1gs{2~6{eY~pA$y_7#O1))7N6^cx{<+&`)cjl01R0cCefn({qK7#0Ol$J(fCNukx9(9&8Q&R@Zb;EvvSS9LigW%SFp+AR@+>D zs%Fvsj=t=pbZ^FPKObDn(KTXAZjY1XtG4DB0}8($i(kW=_;uH+v+hSS51e{9zvn~x zodM?v$=x_pGxuDO7xoaHG~eGaDLOKg#C+z6wOmNW$&S^uCkX_u_4<1JCEN-VBx5}V zN}HhRBqKc9@tv{rqhf(6^W~ny#fzW$gJ!s`H0<8A`8>MIwrG)h?@g6W=*;8pEOqV+ z^7qSC4QkrlwdIZlALzwOO;5QWPTcI_wG$IOx4e}ec4nt}yD?7_(d zddj@piO_@NF|jN zdua?e>!!W$?wIr~=V01<)-LyYulKIWjMuExlyO-DkIN!%2LGRzZqfFDb|v#l%Ata7 z{;jcro0U7ay0I=CihPiww--x&b@@B=1fK1$KPUKfUA4?BDqh{Sr*lqdJt#4Fv5wv0 zv-iCo(|QQeSB@V!apg zu3+)YZM|2+qh8X&j-z~jOyY+3cw4e}pkLT9q|Sf%dSX7boV8~&^~S1y)xJy0H~jUr z-HT)2qm9gmO$ATO-(_c9cHT2vKTf6h!EHw!jc(EWgQI0)uD3t69UnX2-n~Xg@E}`! zf>C@C&Hc2X^9Gp*4RyY1bE;(Xj?VuOmjAIgv?fNGGs_`5!Dnk})NG=+c5bi^hu4Yl zoXYz`&OC*@{EL=Os=Ot89!4;z2HW}F6`W2-E zj+Xru=hn}z zp+t2ZDtjthcF?@p=55-RTR8G!oN|t#gfh|pBiB`^TXO%TrC#1Kk6tw%^$q(kI2CP7 zjd)k6WIrdjCGXOQ6nfuEi?3&H&Wu>_QCpbmp3O1uuHxx(EO*<>D%9)!W907Mozv3O za{bEhD;Zn&1(yZ+|2BLhJwIU;-?_%>OwAY9(hcVNJyT6m?K`^jvO3%&_)V+FO0ztb&7%TW3w)s2FcSw;T6%{lSM4GaHrs_tkCEOulSN$8+Z8&&xqq`-fF` zF(@Yn>CPN}d|`F@1JMgDPL|>RnU0MjVZwz6?%&r7dvLZzefgUto8#}@&%6FCTmI|l zL~RxC$b7hFI?Fv z?>6L1=YFow{q~98+DT0dK?b!1il@$tE#?j%oEckhO71?W#URulzwr26Jl~Ae>rW^} zbUZ;jEh7bv;W_S{?$*;xJsl~!8aJn>h1#9karkbvWklzNMiGH$=leHF<()nFa_8c< z!&1IZr5k)-o7LGae<>o+^3h}{kDhH-Af*0 zx34r9eIGe{c6R}-q~%wtP2$+g)Cw#4KVTpNPB`ScsC2n%cj^qPOe}_`lCTvh_`{ir zhvt>)Zmv4T51e|B@bWeq?UcTJGiIv#TRUfwrzRO0PibH25R;``k!rs`{;iw;&XDQ- ze{#F4@`62+4Q>?~9-%hlI$gN=DzCuLg<=45%vUB(bStGB9eSgyKI+x0K$?wD-kBf({Ua*Qh*6Fsl z?}>(%$AXFCb}K#}`N2w^f@GJcru>e~`-{3m>aAVRnadBE^mS%W&xLngz|9Q%uHs2K zhHI+WEu1IaRoEYCacaE73mIGZy$v)c89pSL6ns^zMek7}annz*@-4ZABl8i){bH$D zymqyyJ~|atnbO$?Ybm|0ZiR_vIi14|jBZ8kPcP|uV3l{nXIhoo%C1j8?y!!Y?VRLa z0Y~!nEOw@@Q)^3$ubFE2-C3tu_OA43=6;?xVqu(d#C;FjZXJ@w`2L~wGuptx6je#& zC$KaS+xTEZv@}3ak|qN#POHcl%pekNBQ2E1X71X~T<`6^C7xkd`s3$0n}d{NWkc06 zMYASV)oCNwRO0P^IefIq(|olb$1lge(pvU4k~tT1zrAU5H+gj8jZWN}XWX&|iUu_i z?xR0PD9n{dr6=NUq@G#`4-7C;G2sv0mOa3v)tSpJ`(oH(;i>-U5xZQM9kx~iv+wf< zW1OP>PSpyVXL9ZnA3QvCIOeW>&g{yLsI204vTqn+;8Q?|0y5drrR9e4?O44&QP*{! zT-)ngV!HqPXSdJUv{^y;$f`dd-y54Ap*kj!h>%P7;-d@hVe*2i1xjnLnLi3b2A`G= zT0hpr+KTzC_HlEY&O-AIR%{Zsvx=`WP_mys5Nb_4GaStOPjmEhTlMMi)K;C4cl$CD zdcJu^U&vdb=hXkLT}5hw|Dkc(%wc~fuXbuJ);;V@a9kh0uDYioO8Vh+V8_3jhj1F~ zs>GZMTckc_{jejgxsSF{$ckl{e8acu2!0=*P;;5Gl(zV&Y+HTE_dow{#bT&M% z4*$Uca4Rzui&Y4y=eMzp3ww?3)${|Kl?Zt%4yZx>=^1M5O-`Jr*6M}2hye6N(M%9d z2&wHIWV901hFBi!zj&3QHWc%8BS0SfGYpW#ZcYEtrYJ}V|5arKRUf*&=l?%X**lWl z>c2zeCbkE?LZk_ZUSKMGVza~iqV8t%o39gh4R~sj==7xBR)C12ur(n&-75`<7zt68 zI%ghN3&>>)_Wf>*cAsHbMYWBBgo3*}#q~t;p6oM6f!W@c96PVa{w2@!CgZ>1T*Ln1 z%u63ALREkzP<96+6;VSz1;$M_y%AT(hE5W&jo)BpXFqBE3pxPg7jIOEIzkWVRst^0 zY$7M91WHUWJ33C~W7dSH+=N+O=$#UFv--}!mkd$iNub(*xyjZ)^Ioz+2+HxzFp?@E z(HGGgWNYsh?2ZLx^+4EmxoUEMNi1G}Z@|U?0xJy8U1W(B;P2qqXk>Pdn}pkX1=F(8^xxy1H)r+qYuRj~~9n-A>)q1%**4 z@*p7a_gdZg#TWWwY42k^SMU-^OirFjR^Og|`16l&)#U36Ced$M;Bnj>kr(o34_TE5 z*(*Z?7O);4hnXy=X6I4Y>p)jy5v5InPo~=fWY~z3l^~x0%|9b6@gE|(`ooFtS>jC7 z!4}8DbRyN;m}=x&a1NR<6{J@ZLKG{tZzyn-AvzR9CrKXy%Djw1*fOZRmV&e48=R;Y z8QAAVk_;ijWW{#DS6X(=K-j>S)HhFqpY$aRD$;RWz^z$y3pI6-=zH*0<|QOZ41!f< zDBOq!{iFMj9VCkmi;#qQ>zk1*MHR$ zlq(=mm*jkJ>IKZ($cy|~J(-5Q2p|S8{+OKk4XC(zV{%uqyPKu>2X_d`Hxqd#s(-l6 zqx~w`7gZd-y=Bkzd&2a1(V%z5#-SOlFHE;u_KJcDPp#hu)YIn@2=G?dD z>gLDj&REa!OMdXJ=w{T+0NdluzGLO_V}EwZ6E%P5w>;pvL@l~%?NMDU7bD@u1PdnV zchDOUWg?JX%j`wyIVKQ9P-G0z`Bh{tc5K7iOtRJ)-4|K*L21NFBBG)s2r>kY{-8n4 zz14}F6#9FUiFXM9x(2>|JJr_|WGpZcQw#Z{WS=Z~#c&m`Aw*rUV>JcJOK!rtlK3~I zF<@!jlGOvrUIuu`U2p3cF3kTpN0exkE690>h1uXIl$gBmBHXEr8N!pY`kr9l2y{=x z&qks<$QGNK!lG2XFeICYC^A3_k+4Ea_ToRL8UfJemE!$vrr)F@bN778 z0H`CdOv3hZ9=vYn^4g7C^4TXKWDoac#d;*@f5=8(oF2<4 z1m3o>d8Sbr3%ZS{LEN|V!-;@&V=!&($oGusQsg5hGnFJ6_@hRl$xm6r{PFzzv?RuZEm^9Pw*#6G=r{te_yJpTw4d8#;;` zv;n`2pM@!#tN@0TGP_!e#ZsmnQX2C>p0=Ne^Pj%yVG48ImF#tI?lx1Mn*_}W@j6#o zSNwy1viBGIJ&^6wTro;fM-rSrm=hBM&VFJIgYrF#PrdS9f`;(l1VKCF+NpaI$oyfD z4#&PFLKW>!ve~nXdGWP)&D2QZ;~=LFE7$)6p}J*|d-F+Z94y|*a$d9MC5foaFq96I zNcsKycN}b~TMV1nI8Zh47*cQBc9ECu%k{Oi&O+9>s^hC&KgHTdM4kmJ4aVQf*Wehh)HH%?uw*ndk zex@|5!d%rSEt$Kl8+2t4V;OpwQ`iTD$(e_ql#*Ib!8U-qu@c_PW!!F&+5p~m6;d*D zx=5ZIxM#?-yijD2^U8n>mX_9^V;2E}^YWh8gGPc!>?$~O zVlpRYN$4KEQBsXLx5(eF#^OYh0z|U)!xo8&pb`=<2$4(=_MCzFUewLxNK@nonjMLs z1wz$xLIy;`j=C`Vs>MheFsD zeBe%YrweSJ^Md_<6@_<5+t@nfUCI9EgKf#(sQaQn+J4(}=g{gIjnpGY95za`q;5|` z?I$vq6!l~Sr^hN35|4)Lb026Mh`$Y&jqqS_>fD7?<_wdb04tIMfCECUOVtS%bob%_ zXm+!GSr!S&Tyt>fkzS)H7cbF2S2`~32A@uaCIk&&MT}xragmlW8{2$~V;EkBS=hb+?$O1}cigubmXS;u8( zRkZv`jOGol*Iiud$J3AXT563DBO!6iAvi#5SOnRPM8E)cuJg?c@aeHwUGVc~uYW+; zMUU<`cM%$M9G0yK2oo5owvtpd3h$8v@4Lto289iY=oPGNMkjzMi`B##DktaTl#g{; zq#Q&|LR1k4qY&Kf#B!KjgNQ)bWDM+ezCwDujKcBcXsek3$(h3gaY0s8@^FCDM(O|9 zL43lQ-`K@*njtaGXZKRvK<44wETuP{benLFJU!~U-I=ymg)e&RP4(0pdRH}As2}t) znLVIpKCn{&InyGb)leCd$}a_N}XBz6aUB5gh_9iKqyqZN&p3(l7ONkUzr zRRwV>Apd#So7;T`orB^llVq1>9XqXh+aKO`TN2j2Y1e!xxuNRcaElqXYb)HjPM2L} z_OkiqrDPLR$ugCErVb9(HdmbGq*1_;wtdICga4&=ZGcG>RFDOMMEF63!XXFF`Jq?n z!AVD+XQ)1|4@6y`lJfEo`3tqfOX3nk zs&5Q##u?yf;>60uL~}@&%_~y=-+4N?UB2#AlK6eXHE?2ffmgsl~a$xuWD5P zi(LlP7$kaUKH$<=rb2>Xk<~>UkJR}Q(9v+xgXp#aEk!$k2qVn@h7`ONu{A*b5?VKT_Xq#DwKV7x!VCVXBVM5ZJXzX>YUZo899e<13K zG}oW_i(6H})s9w2IfLZ>6aECj3d|yZv!UJ`M2;4HI-QB@lBP_>36Qt(;oy4O)IWml z-FB9CJ$0$xZ`J>{Q3170%arTmRQFEQ+z#vvujhsryp9H8Qgtp96Jc;q5#;2=r++WMJ9Lc;4@l-ZtRgzLvDQe>(>Bl zkIl|QxkQ9atEGa9AT?Qyfpj?q;Xa=rfwZ@glr$xu6`zY0W9HcdQ4;U+)t(1#^H=>< z$eXQgIhCfqfo9hrReis>yGgdJZ4u zaH!HsK1sckxzj6ZQLl}sAbgxjn)!X{LYL4pzUMof=H0^EtpqEEHyVfgMmi_4GoGYw zobO@rdDpYBoKa}(^{1<2B0R?+B4I*CNO#k&ZBPl@IjrfUK;tb%qAM}Z7WG4+UIqg> z)$dv4kKqVgRYQgw=m`l93c($D6`*XogN~kYstDdg0uv!ZoGy+t4xN7(|AmcM-(sav z3r$cN9`i62Y>7QL6M9)_{yZo($&OFg(MFycAM;Q?hto*q_!?T=Qc08%4?}zm6pSLhEFRzUt&1~ z@=TZJLLeC3q5KG3%W&$Kc zGe8Bm47p*JA5jL$0+*_2adBuQSHW!wp${ug3IgfCpjZ1k67Uel7Pvgck=?Y06iS?L zpU~iz_*t!zSC zJ`6$HW5QT2vz~%2DZtGi_~=0C$Y)tok!30?$hv8dbfaJk@^DPB-uX`xfqW zy7VZfKh0+9TX%im@7&YU+DF90-4~i^yh`*3OEj}K+W5uoB`=~t=q z0IH1O(y@+&%mb`rtIzfdI9e4rTl%(cJr(JzDb(?p-0G${byJxYrHS@NqoQkQWOILZ zgzbP;L2KZlBD&1)U-=D3bV4vkrhcKBr~vDMi<&qZ|su5 zWz7lBB)1pi7SSAI06NppGM2{QbLP#@t2wT`eC=LJV3hoGZQ>K93GsQ!=+Qr)%Z;7ITx)Gn6pVO*N-> zO4NE4NckHMi4QjBI#Af8^s?^iNJO;1ew*L=z$@5uCno7^!4KtkQmg-Y>Cnu<8f>0; zpJvEP_tX=MLlYinoM-u5Z&~_U<;j2jzHQ=zTwrU!J}LKGx+P366W@8-oJ)~fo8u^( zh&vyy=##p}>`s*Un35)+MC$1z%LW+(M?7GSaYm+0&)r(Sh?QCEu4XykzM(8vbW7sX zh{g8kpCbbTbKBjxb=$Aym!E%kqr~?6#6OXK70oZ|XT(@jGDey?eE>ras_Ko^ING6i$6c07XOI8#4}6eINiynrsMbK1%Mk;Hp9thJbU-f zoqs4qqfcox`c1`Doy6BsVjD;2eW9WDK0IYEv)Ey?GM%;Nb8j#8&YcZre6)wk?!I`_ z_zVMKoUP5Z@GjM5G5`_UBEvO3sMFi!uK`oxn#9^ z9Ax>VnVrCMeX)z!iO0H+l0;j~CY0`4W2mApgMwhb(!=BzAqI>{8i3G2>V<9^bbvRar2^Z(5f>n^ za?<3&-P2}n-kL&B;^je7NNiq#icjL_@QPNTx=;Py2=okJ_cexd-;ItbO-Q1d19612 zGfl6e2ZcoiK6a9532pvr3Xy%1d`$p3L(8VGO&KX93OXKF$Fep=*)`2lWHG1?6Qq>Rwpy z@!?n~q+QGpFBS+syYem44LrL`3>!_pYV*>kXI%GpioB|kk>M^BI5I}J@3!Hl(d|Xc z)*pYF==9<4>D?*na#ImGES4$Zc|Z3)dd{x4okP<>ruE93D)u&!XIz5WL)YrW=Ni74 zZ+O|~@lmP%4tL9*&z1W^spc6(}$#&4_NhUHSeO|5mtSb~OH|xnW|sPoJAu$9VNzmG<=Ia zXQxLlwjBO)I(NBt(0<lLBI{9vzA1vGwwY7v32D6dQ>MV%%4Gm)Ax4c$d|H<0bLK2Xb#*UYD#s zaBntbW{#>S!c^T(j+_?(-WL>#FL?(&`S$D`4$`We2UxG^HZy#jFg&>yT4WCqNI{Yj zNBnt+cA!wuDth59HiVY@C-%He=T5+r0YezfG0c=J;9kk-MiyX?m#*GPW?BOMV%8l0 ze51!VBQKcy>=rnnB^?~H1$$gDw!bfM-URdsVM(3W4Pj=a#GwBKJp&YFDGr6Kc!m>} zOrZD{WV*%{g-0V0$qN%hI*H6cL4n)M5_75i*`_2kXu8DRj=(iv04XP6ks%(R#@K(z zHbETK|H#M-*GCF%`d4tjGAF~_Pa0I`N<`#WVza=0W`A~&PuOa;6``I$Oq+o9k`x@* zvJ^9#n6bi1zf$zn?b6t8PPlr2(B!gVJs8I+^4TT z#==v4^%o0fTyuL!QRqt3ammu99g3Lzk}&ZoG{MDWO#1 z&TOi+T`=DL%T6$_dDt~6TJ%tt2bceZQGWw-`V}1Rx#9AH%Vp#RBRZ0ra;Iblb8OIf z)J%kJKA>&8pJN;TEo5ZO7>f~w+YgYK2fv%D3jUL>}Y8Y8`BAY>7E|>Uhx{m zQ8hYSxyiY-+r2trmurj)w*>l{R_SP7$23m zEz547Z^;?gWWG_|6E$X`eJ9Z8xbUPF%4waq5+(kefs34Qp(!p z-AeUO=Fx`}+amCy3Ir#_YyO%Bt`?8SHcZ^JSg_{|R*(~2#MxYnCVAn+M~^oSWy@xIF8o>ze~u)ubs{ntxn$b>Xf0e+ zad}5UwV*K|t3L??GxO6MTl#>GlkA8u-~f?pN@8fx0TA*D(=@>b$H6(pq7V^0lEjL` zctt|O5%w0`@hD1wu-L%*cY|<3w%AX==NSMtiTD7`5YmdaQjpD_9k~njP3YZClMZa? z_%t?_KJw(S=Zix%qaXO~B#uj;KhHvvf4+QKN;2#0k0{R={T<=+3csPA`g`DN0e%kB z(on7d_=<%g6${vlH8LP|sh3luf3Ef&!GSuT)uy_#{#j{1c(7p9hks9PZl&mRL%wiZ z{)KS89Mu;$PhEcEqmq^8{r>W45E%s%3Mqmz8|MUv)~#6c&XjH92X3;|zT88zbJLyT z$_wr-eJ{i>i!Bqkm~XwvXV>~ViX%`#UfcEH3XJY_DUknb2(TO`FxRtjHGP$ju*2{8ip|Tjw>c$>7oRPDK(2dTVpfTJTr4 zO%Ff&VE$(7^N`iAX~pu@xUJhIyVf`P=u=s~OH1y1Tko09GI#UWP{YMY9@??h;d0ki zBYxe^rOK+z38&_$Kc{zN`+>eKDc^TcTRVkn-9FFCTzpyoazVJGuW`@NSy3%t?M>%; z2ESX+b91hFJjp`uNpGL^)%qIaH@f${mlJ6zF(F3dKNaTZ)Y!r-h1zpxeqW-xImFHs zWnQ?+&J=7z*9ox-XXbvD`Ort zNxs;%>v&fl&JOZ!)JGwc@E6%Ohv^7U3)kVP@D`(XL70SGb@{z8Tk0RHF?hAM&> zGN=zXx5*DGoL(uWSbg2Ry?g_cWzPMycG(OL@$0c4AMUtSXs;IeZYb}<>fZ}v()-2T z-y0Rm`l!d(2t1T+mHLriGjFW@s!^tEo%Nqetl{^ zi3L|!f2osS`P6TysPwN7KPo+QQN&@;geA>5LM^8HN zI;u_Hcz>CR{r=l8cGC$>hPFy|=i{m#4s~)xvhhDM+n2VIb*xDtK(U^h5}3hp`*fkF zqK)RZ#r($UWy*ZjGb5c`=>lUCXTzW0oco+AGZ*=3{)F-OK{H?Pfqnb<%gZx(rEonwF8dFuJGtL5nJkumEFSnguHH(C8=?nw8Ixi_po=jJKy=;i=5@^_Hj`Ea>d zWwM9^vyTL9hRRTkkOj+lcvn%7O{~+Ehba=l;g}LzdAg7xfIS^#t=C8;*T@j$%@5(bT<4c2GAFT6xcRyM!b-gw6G?A7Hf$A>%kM}7!d zCw^FwrW||i_{Y4^ir24~Q?Lm2Vxwl3LRJyUYsHY**-KK{A!z~odkYo)8+SUUKR}Td ztp^h-L5^Vw%fipzO(T4DHVNr^4D~1DWDl9qIJHgH1y){WKaCcUoE4N~`1T|#jtD zr~c(qJKr0Ix<~j9w(>`|AB|Kxc!rSLM2*Dem{=wi}}B$MSqJ0To^vn{L2GR>i@;un?PgPzVE^}8bpIaq-3aM zXreNdkQ5mzLm4s`q9~F=q(YRTGG$1aD^%urNQRJP4xxxLB|~P~$EElE{lC4}{?_`h zZ+&~M-TJ-t{@y&#b3gZeUFUV4$9WvbxrMKt=>iz;j?ZmxRSei!TJFBOE#OJ({Kmb% z9h&S7=cffjYBuM14h9N`nV#}Z91c0(|6O&%8fwTugg77`7jf)pA{t6h5&=_Cs73;UL`mm@lnvj_(NgpZ05EDz*HS?VTG{ z3K!&jY_*PJ@Z8}g~=zRQbu1=k+cjPBZ1Qa6y@ z<6pzl-1_=uqF9KdzLUwQR8?ZrO#YOsTb8K$tKgV3>l`($l3QXc8-gZHn$$zByh4>s zHI!02%^t4!ZaA>rv;0iC!boA5NaJP+mDCl#UU}%ZwSLX;>>8fPJUVIFc@6OHO10-K zf6RvRqarigrCe@Y?d6Q-7cF*QalwUy<+ahpnop6&z8v+%EN6^a_KvUdR`v5pg1%C0 z+jt}?_Tc>~SFe5#HtRX|M9j3g@%QzJ zn4$cdoxKg+ro6HHGxPlB#D-;GluC_-MhInRn7_F4`0+(P(^fHQ;VZRaAI;U-o}5kj z9G1Q%$0>C%-EG{`ylGqOxJCT_-U#dT!F84=6kj>jPnq&&O9q|KzbE>ErL{e0^kIAd*!4|yxsQiqCzExr$RFWxguBX_EtRn zP&T5_wYlI-^;hQF)VQ6I!bamm^!u-@_A8iLZX(Oki^dK+Xn(+-ngwWoM5O8%3Zr1X~Ykumq$emp*Q)(fPhVSEWp8bQ_z>?4P4i=JK6P=Xt9q zzR#SMf5`-wLHW~;dyh3eu)j-t(R);SWxCSZYvz2F&RpH2x#1b=uOqXc>tC5E(Jo9? zI`qiS{8{>Ty)!CFL#C1QP6sO`pIZf**kuKK*vD9ORaM>NXllmlb9OvQyZ)uz*H?X} z*NjbEYOQ-G#~gU)M_v~VKePztwpP&i5~CeBahmT?#-ht({dMa^vUIzge{5Q0H}v^w zlN9^uq?L9Z1p&j#)d^P&V!OO<-)zi&!=}!2&sV+mgH>S6u5CMx^v_RaDO(m>G~bX& zYZ5xB+GUq~n?6u!ewudw)x)pe2GtmMn51Rv8LT_# z?#0~IX2thhLz5E!5}DOu+7zze@Hr_a_JEIc&i(%U&^oCf?b$1`xqM~rIjy$nA2jas zsO~bpoOkiEgLJ`l?y@&*ArptbhkD*`)-g=+o1U4o&J~baSI^L5J32={BYV}Nt+7N! z=ZsPXV}d=?O@-UG{zt`RxPmMuGjA|1|MPp*g=A&5_zCMU0cWRv*Qyz-7ZZMp4Tj?B zjLz!|PL0mTYq1`*PWp522>%l@of9BJ10>3X+Y37Wt8+15B??p6EzVojoD+&O~JF)0Kl= z>9(!*`se>tJ~x6fliCi`R<6{qxb9b_wL%|@~Dw9nXjQ?IxZn#DN9$=!K8l4WRY`v zOO?1&^PhwYHh-P$@7F2tOeUJ=#oJkSJLRr+JbV0Q2V3{h z{2c{8hQ3-eK@GUu(cZUCj(nQy5HP&!Vt2|7JI;MR%TjJM={S#wJ&91M^JR#(9CLox za>rxrjkA9W2AqF^A=B&h^c@4SEJrW4pZ)weBDd>@(*!7e?Ck8%+N>9!y>MPvH>~29 zuiN3~l};0mpyiX@2X@Jc5WD_HD84jTp~giH9$0ZAT3A?^Btitc!U>@;CgC8`Ro3STde(J&zZ1-)a)0g99KZIaGW^)dFdLZoD z0cR<9P+;CPHp0Apy$ zMZST~!01Alb!0!91($4X<R|Oi;BOYyG*UEy4 zrC`oHYgqRW`3%i<(99I;^78UXN^cVrNRE)SuAq6r#+K4{Q>LEl}Cb_gU& zV$n_bTyU5)X|+;OQ&%U6GcjrexB0=(=z7zSPY;se6~tpzpoaeJW=@n{d3|?w9<9s?rPFy2 z_31X!R#u*W@jlFX;b&=o4J%2=+1@Ihi`$!Y|84FLG3KO%vzE&pVtLA= z54G*J%y~20b%Y~hOKjPN%}+8P^ylBN+joNm&Cb1BuU~N9=qhHg#X(Yw*U zlHb|zDg39^Ug6JDN4gx_w!hhMe(QBcN-NFuUW-GHw)mzK>a;_DCTjW`^^6R;HCvB$ ztqhm=$z?fu*<#2*`np@eeDh0du5uHb#F>b=)7#&jP1Y?ild!xuc+2aa`@@WIwRj;{ zna@F$sreU299Wa2=8wK?Ia6D{o*_b}x;$mi�HndL|EKFKRrjTizh`uFsgzA?f&t4;9a?dyuG zE`0R&pUga&UM<4Luvo3&&*^>}@7Uk&%9ScBFim#yYTd?!{2QFdm$UFkO?K?-76~q$ z+jLQVsYn)lWIpaQ;8`@^mMVWBxJF`AaK!P`0j)1AAGfq0H*rjfvdn7!?PeA5;?1wE zgPzJxO~LAV-Hx$|Pjcg({M+o#HvJjvZIX?UnV$G_PC2!_wcc$uZew%)&S~dB+vbAi z!8^0<84}ysKF&J}x)y$1tW@t8M@jjJm2gGfG0XXLd8YZ4v)PYYwo@GiK3m-x{2PAd z7_fie-eI!*jq4d178iPt|+Kt*W|X0nxKQZ0M{p9n^M;^A)a*Ta1*`CZ6r0Yje8n+Y0-kz2c0!}VomF{?)JLfrqA9VS5%Y* zTb&P zY=oL>31W$+_7h}kK3vi(#l*zWqdXik57XGMpnX2mPT?b8#Jb~bV4LW6t=ANRD$&dM z(%E?ntiyp%Np$$mM^Hb!MnTHXF+X&65Z|K)7JBejHF|fS2jgrD!520%G;{$A^~GlD zA_is2~)a}s=n$Z+~U zv!km3+$1D#hxyyB_;?{Oe|*r}ynu>srwtn=+UsWNy&B4P*_#F&qSQaL_m zpAvah%nj%hW;~P-Zhl~O-nEYPiPiSP&j!54T2*y|5(nj+#VXaa6csXZNevY`iAK&vBXQ{e$k_IuC_zWY3tZOxorL)7vjox?pqS zhLxXY*z&l5>Pz*nx8Dd%xb5t;@4T|e@4>d!)@K60?=z25l9fN*$rqQM&gg#bfJAWW z)FR^o`Wo?7AKTf+RAho}M*`O{hCEcRmHI~b6F;N;Z@lGp+!z{Pt#hohe(Tq&i7mJD%jC^pr9bH{rq>*L_D9*#j#~s`&G9Mkagdx~T7?p=6bl@6uAj%EmUOi~44?~Dzh*lE- zK5|wAT3M)1ckYC{q23usk*ynUdhr}1qm2| z9!#NLgi0bX76buRxF{C^*Y5K4ch8S8^>s|dn zR%S^(c+%FU*t}&^4pP{Whga@r_PlN!Z+X9dQLt`c|00H`Q$Lf#etzmZS;cv3bcx~b zj@J4<=azUnsmJ>?hl)?xEn6$&Sj(4VFR>-q|H|GiciT^@ZH*hxXKdbbcT7bzUul}J zmEp9i=jh$Od7*dgLryBUuAj5s$hvZTQ|)4ApRqRr49n!(zba&%)S#tzC`jV!_F7gp z;B<~g?%1-j$R4LpL+Ghu;M8U^-Ult$bKCs(epLt&VD>R^`prtLw;zRuzWw|;_jPDY zBmiYdf@=0lygYXNc+;ax8j4CvbX4MKN3Q*CXbyl+QQ4%|U-?Fkt8&e)Qbho$YcHQ#;gY;!(#}a25UhgGkH0$l?uq}Li zMemh=e|UsW4=n{kqAI(}G%p5SA*d5@Nmjedn_-TIx$i|#Xy|=Xsd)yW37Ks~ zk%8Q9pE-C76U_WBF#7~b+cr6rJr(gh+Y-a0oU^Asd%4P4o2*6$gd)Q)jp{$rkkGJ2 z@?5pM)$H*nvCiR*;byT>{*QQGgH=8^Z}KczWcQY`b=*2As5sV!9Ix?vCf5+n#kb*kA+(9g^F6yKuQio69=J0+aCnd6zCf9{Q_knQo0x1?-j>Br9@d44 zFEAp}SAhwpadK-B?tS6$A2D)t4&}1b_kv445*l&J(-#_s3O-3oDoXl~LnfkY5TH z|1m4aixz%8X#a=Oyuc za0HZqBSglg;E$*qe}K&R0FokU1cniPK^T0Unkol#dZN`p)-&9;Cw$@NKXVyg=kxc| z<+KSy>juU{!Z`PS+*dPc@QEQyW00jMzmr*3NSjCp7af7_Gxz>JL&o z2AAN)#2vXwgmijPQxSV^;7pG&;fT1tf}X*R_AEMJQ=csso^fQ+n}26Adn9A<$Kmjx zB->s!st__0jqu=R`2J%$;~7kdxHkEV2eP~;->>{#E3ABoc7H#AY$bY35M9zjv~4j8 zlF4F9{QBgjgJKvAjbHp7!={3C3Wg<|#b~_7b6RTOfwCl+|8m3` zfRsiAivX=g8nrm{>_=l|8V_-aj&G`1H6Y&%$9Iyux-z-7`SS`%N#(x|mEocgf!#)C z#J6u(vD7!34UdV5dGqmCTl5n9JDdJqs;ZQFxFnuujd({(i%!gf5@U7EjMr81wanLy3Ci4gLh>)6C;v7TnbPV%i zgHXQg5)<2b|BC83MiJ~iwi5v!U-TnfvT4ni`o{B1zvfbPHk7W|cqTV&crt8+Zf13_ z;^{C=_5w3p+t9WTU9W6}jm*q6zw~=~cu*AXOMGFI3fGc9NXfG_nMv7pg^C=3G(B%6f9AiR;;)JKMlnIIHqT(&wryVq2E1?v@fJ9v#{XNV8=i>}$zI zu})E-H!@~x|6M6lJ_}TE+@eUyYcN^HzYvXOKrg`k#T4Pe15$VpJA_dau(W_)(dFJf zrJL!0$FKj;`*-j>X6}*|4=w@af{I6}4jMrk3Ns5z0n*Y?1@qHFL^lNMbzX6CEuABf zk&;Pfc+FFEL5RUSdKRydR{2An4hh?~kCqF|lh;A>KLN0U*a~Z9*~C<&!r40Dt7-tw z^I4=n7#E10JxiGQg|)S7h=np#4^aHaT)+5tO0=pGB-;y1--G`itI#P!^vL!FtXR47 zR91@cZg^E-tsB7+eX2=(;bS9tkMNN-P5<{I{`uWOy8NFw`OhwwlN+0djO6bd|MUCU z-!S-hos-{tz5l(#|GZ7@`u~K9zwZ+*_K&3G-#7Z-9s|+WuTGDdlgY8^>FL0|Wq>z9 z$~RqPwU+f35<5&N<-#NfGUXk~vnnPz5WCE+$HVKjwxUY|Q_`C#JYhTu_4eZHqg7%a zzP?Rb`k1BzA6VSNJHf&0HgIy*pcecLVuzj;2Wv8*aaoAs`TV{DnYI{h_EiW4@D>G- z{B=ggo#)R5boKPe3nU%J6k;~aFq!kZF4U7Sz8BxWAIZ}OG}|W!QmXglHx0;_G$F72 zg-HV>&E(_b!(ar=Q*>(_SzMzDM`Jn|dTMEFziDaVAbcFaYa`juHN19x*tDh8e(WXw zbQ1m$m1`I#pKTl#M|aTQ7XyL$0jRkK1O$*N(3p8qPHx^)(|Z`--cU-9G;y<2Dj-Rgsh&x3Qxf%RuNOUjZ93oAlum%U?MHp z?q620a^=ce;s(&9mCij~fmr~HDKt9xp>dWwapFD@J{FN+fZaS?3rCQH2%DB*lrgqPGRHr3mD7Dr@c$<_JavfI;Jbiq~1}J*-<{(*h zA@dOIef{!vij9)P(cRF~mTRw0G2utpWM6YVcrid*bL)o??Qy(>0h-+yNQX7@Bg-%2Y=FPpgtQUnG(|)@MJB)@i4ruyvkOHty71E{cHk-=-rgsJfl>dVAy)h=UFJM& zHJO~;%6=@IW$W=sGX8lS&e*|2verJMC=0G%1{P1PGpN^wEi^-5V4&JM!S;<(*F8We zf7{%=?&r^+RPvB?aGX%(7FfrlpZ)UX%U!#6y-SkJ>%b_D8q5JE3O6`dBsSWRi-tQ1 z4TVJJ*RNkPV7iiWoXmnEaR{XRnB;*&(cXD^2gt-V;{A=>h@H$R_u;$;H>ooS0PMoX zi?FvSJeHpfk+05XtQ_B7yBK$$wM(*MJC+q2zKirOY_SOOGJE&M`A4P=krpkX49}d2 zvd&LSOVe<_`wB%Np}oUj2dQ4{*p?Scf68@tWOVn=iDDal&c4D(u!_mI@x|B~Af0TW zNRNSZCBQ9~MH|5yEkZK)5>*_TX1fdN8*+n2QqU2H2^8#;gBcysy*Oi!U`55`NRUhj zku*8}ZT{hG5+(7^H2$%FBknuGh>O2iBJrV%(??Fj5X9Z-i6LL&ISmjF^L#HFiFvlu z5rs2k1;k7Q_XS1cj|0&s_$dnM%)%sw++0I*b59)PU+|pWP>CUvWy}f~lX@g+FG89p zkUwn1*j`y(-3??S3X}*WR|OT?txXOY25p(R?-;q7yI9n0rS@`Dcb3Ulh*ez3PJeF? znQY^>15T5;wk0b!cwjHA$5?jNEMxkQiGR-S2x=V}b4JyD$j}f2vn7AEUG3=B3ZL`X zqndvM-cXPnr&i?Rsw1EhBa(Xk0_Q3zIh{ew#e>TVA|(^i$ohj`ja7E;bEYt%TWs32 z>H6;mO=?n4E?dYW|Dhh|-)CL1^1pW5e?4o6O7qLDE_^|xz$7aD{>+jwBXxcj%t)n( z8X1X<5u(cd%991oHXNt-^=H zUSXxeC8j~d1#&MODLU* zWeSd!G^yEi48FiY2DaX9_!$hMu4GpXyO8Lb?dUznR)gl zCNAzCTCFrPWY!Pw-o1NUM_}~NL=_`*JQ)i&u~QTiQcv|@XAG~UkwGzb7t{jtX@t1N z6kN3L-&lX>uut+uHyq>2#gk6L!bLSZN$wU=h$SUAlDj`t_k@ ziP5XpVSi#*ar zwpS@Rgt0H!n~QK>5k4}CA`273U_=v%ZqU@+PB zu^Rbn^Rp5#Be+ARb^z5Unn7J%-Aico!DxknLehi!-v@f7uYfc!+1SWp*%+Nae~&19 zP$%r#z5CF)bNrAL(7fhiy$XGjxZMCd^#x{KgmQ$Ndjf;PDH(ikQan(L>om!f5|nmD zH8rcTiYbbWNx>ZPVghMGf0f&>Cp#p8&4AnM5QN)mg_V_y6e;VC`0=HS!#kr_wWDi! z7>sT=1a4dab9;NcF`RhZ+O`QUst{sMnW=POeFx@937f%n9#ofGiBTKKd1Nqf%;5xPC5v3U-<9z$WO)q zFyf~bjFdTO_kZw#f9e+pKQ=qMe~XTV zKYz^qYxVkX0rUS?`9iAj|N8Y0iuIp5O3nfK>Hl;k|Lr+uO!3Sp&0|I z=z+9%tqh+Zrgr{$y#2}FZ+57V-4_TO>q7B0cFfJ)9i!f#{eRBb<6fn}=OVhS*~O0Vh!)D)SwJCg_U(2PCs6a_lT01qZ(=EnPy&8l+As zi5XY=PhE3#Bn1)zBfIz|Za9O)p0WY~;73C|8+iE3VHNl3)2FA<9Dd5Qj*<#Z#N1PE zm_=cI64m`%yNE2}+O^CNl}xfv0f~X@AcTN%w_5to2Dn=htuKuEJdoLch0I$c zJ_F`2NW_;Syc_}&M#VotDrr?6gV!lC0)_O1eqXa8npxEf}o#j z)y9E4_e?!*B^1f7$k&kj(V=Bn1W^Q$P^HiB%%7yi_98>ovBQWkT^UAp7*|#VNN=ot z$HHwTI4^Dc<37dBcP#9gwFL6CJc^qzOd?M3NB06QFRH6k zC*>Yukhz1Q;ZuqmMSlsJV9~uuR!J#GF=3?dNAvZ6HvxN>iJ=#8eEYw!q>VXF2g(I# zc;V;Ix*wHN)W9sE0h1;O~h$EJw2F2cPcdQI$tS>E%iN6@5ot(#KdeN{#F=64@86w z-*AB+zJWNVz!ie@%%A{#QSUhN&GW-x)eB_wWCRZx2ZQY#37sCH^Vx?!24(YY zH0;h`?z_b(xQsLu?(m7wICJSe7WMQvu8A^uN4o36eZRZA^x8W1F7yuc7iJIoE8*Fp zMS>jSM5t!+*>c_OacIVKe3#h2U&+K)?~cthTA&n#o)$UM!AJkv)pZ+py%6jpltU!p zOouxTR{rDG&<1!8ed+7F3xvFjbO1rNT!ebu7TP16#SNJ^lJENXETKRVV1;6_L1za< z9hM=nbB2b7z}p?ox&C7Q-LAs=OG>X{*|RP5ExI43rR_!j4hMi`=l}%Rqm5bFU=uEV zwnzAt>+UBS4*kpTqub>Q@lT8SM9jhFC3u>Tv|HtrF!WF)_e0s0&o;}56%sY7M z9GLg^24`my#>5a;9x^o+H`hkj-DZY(NkWhD&0-Fpb2~INv?3L!ajwH}LCvj;7w+P- zYiGQIZVGP%348Gq0*?DD;VfN3A+LR2wcmLb0qgncj2-RyPWRw!`Mip8#frrgfrBjD zv8EbgNw$t~0H#@Bc&|0G97+b|nXS$deR%(I01qgFUXlcUjNdlkXpx_q^~DvWIiN9d z2e$$}Q*;OGbraR=M1c%#!ft?Iyy1+{--{DB_KNdaxVW_ohf>xUOciM<_=P5n$zMWek$6A5wCl<^jr~I}nFq;|ujAw6B+2B3HX!i>0vVEhm)9vNet~&dDgh=Y zOJv|w-JLmy$Z3qSfnc5Z=BLPcfbhxy3&mk98AdnR2(@`2zD2kU@}U&1OXk&~@IBJ# z=n%lY2!Ru{407WLq=~V7=abHsuYnW8&VUumm$OjCILC?s7Ln1i0!%kzShnm9B*w3=4I2ZrLFaw+U>}^0NOg@8F9Nj2?yXuoRq^zteu@eHR&Dm;Mi>kx@2x$Pw+yW_w3&ric@ zk8Z!+fD96%FWr>VA~t%Y?F6qm7ug8@fgYyb zvqi4UNOuCMyO9{NV!!(UNdWR{02~ph7fM*V#2mPUoueY%m!YeE4zMYH9RDM>xC}WG zB|4~Mv5ttG02F|{Lqew#WKf=iEqCIRO1U4>Q>6~h;6JD(iMVIcqD90qR_FF^B&^=Q zQ|9x4%fE;^!4Frv9mI(P>7*=*!Dt&~tot(N5F*(zJP#6ILn9-a^XJbKt7#ynWk{Y! zGR^jcf-Qu#xyz)2_ir~fwXHLy-q? zPO=ruu&o-kV_vOVQ+(~$+?7A`PMB(NW`1@GJ!V~_&!@%N3%XZ!Wqdv}8|yQ?p4IV| z$K3#&x?ejxc`15|n7W`;VzEoPn}~KJ0|hnniSxn8bQmebvsZ;JC;jt$XWk;TQA`M; zI|pLF9bL$SAr`^G;XC&I>sOwHgoNm!030YDC=Al3E9QuhE?`>w54<>|_kk{6Ym+k1 zxzAQ1&hCz@D?I`sjSM20@ak4A}c1I(FC%2`*?!h-4Lq$NPpbhy}c{RbN<9L|3 z@Q}Aq5hwwhDe~IHxij!z(COc)m#qO>2f9My;vRSvBOWb{%%|pNPqakOC19?5w9|4K ztigjMM*#t285IfAcPCm=7-5g&)GarVaq>xrbiQ`2=LCM@gAM z(F68Z#mdI)#~qeu5dVsS(d}(iId*I<;79-TayX%7=HxI2EAvqvKz%$xp7g%5RV8FI zv4Zso1kRy;4sq}f_{d;=&0EwNtv|^@E*tw3gOKfhJvwF#- zw*oy{oOgo}BGE2F;eZq8d*4&v5iVw*tO==~zN@x>efGS7Z_5@>Sg|QAC9`^Sbk=n{ zzf<}68%2R?+%Zl9`H)k_8eLF<$@=b`^GGgG!lV9C^gvd3H#K!nz!VTJ7UD3Gbs(EV z>Ut(>pQ-Mk>)QI(j9R}XJnbOJ0oRn9@_<58@r@QrG-A?$hFV{r;KPRx>-%=V^NL~u z;BAcdYV8<~1rpwAWRTnJZ?aT8emtm86Ad#EP4z-1z>BXztCNO|{gh3#0)owt$)6NA zz}(G{*Ke!Jx%}OSlxtKm;H|wcXV5!grCbve6T!xKy=M-mcO;T{S{Fj%+uP_g_TvFk z{78VX9LJ?rgAj3V_{+@Tp=07|^H}WrOD|3v&5lcZb$E#2EYf%nVJCV>0LfPa9A>5D z&?#fg(JnzUxo}TA?_bh*VNj4yaZ6!5Vw(N*qJ)qi9Is1@i|LWWliL8(wg+)ndXC`2 zu^HC@i-mxYZ7NxLNrHs>eC*WZ5mD$aQ30mQVbNG4$kRHd(rm$EusO zqZV_8F$l#?^sH&wuvS@3xdsOZ6GMX1^bFip>cp}yo*qBu$`A_BSw|~OtZkQb9amSz zAEVhL{x_>HZ5BL&SQl-^2e|@qE}&imT!0he3;)NEa03qlhMTNsiqaTyAC?HwqC3$> zLc?e~3bn-mO9>+q*(f#w`W9iO-2!FT3U2(7B3824ToC~5&~PHro06fwi95~1&rgpu z-YqI>6X~YMDzCHcZ(M<|UX@;zaj~Nm1(l%ps-d19cS3hS=hUI}rj4F9Vmb%GR3>eygS zsD_vkzz zERya~hR~&0jH)OrM-N3H;=Fhd%VTw81uywF%KrE& zMD;zO+|W>{K{86Saq^LjU#E>6pjA~Xz{>>utR9>JSdQ3hp;-cF;FIy|F$o083TOzc zack;6k?}Tk6aa6(x2M)1HryIEbdOxKN)Y|})hMXkA?iUofoFdQ_kav}`IKv)qVWOX zC}79M6huuj+nwJ}CuE|MAY2hJA5a?HMlpmsWEiZBuju!3v_6#Rfi!V7X~AK3isBV) z-}3?AfA`8ZWooY5Z@w?T>hnu<>`7}9$&DLGBlx&joA<#W+@Um!1Ox;EmVEicEy8Et zG2sy_o?l;AS4Zl1!W{;g0X)AWNM?Q^Rq8~%$!gI@U^_EgI;rA094c6PaZF#J`0Uoz zth>4!MDY=<;9#Sbb5dMF!hC~N7cLS(e;N22@GIQAV+dzo(fjwx6cd;(8o?Q;bI9(} zYwN;Mgaj%2>l*yNX9wG#8(V5X3vopTg^mUb0vH?cC{|P@vh+OUR8Oksr+z-E-R(5-@dCeq0>LT_i?YnF4vUsa#p?H`=psjkZWuL$1!*QrV zu_#Ya>#EI%$HrJ18ynTK5M_?G5`e$2>D5j>l_fH%7n;>9!=_LR9UV<8|EbuB16$$D zm(yBW#K6=1G4Rwn2Teog0bevDJ&$;EAVGi~T_pZr?a3fya>R_<;>zW0CDpol^i;yg z6^0)q=x@M;I6uEae|}}Iu9NgVj7JD)IeY%RkVVtW?Fci`)rxpzHo7-VCG>bgA3o%T zYXScX&xa=u8u0Y@_k(1EdF!62;)O9us0McSATBa4oY+{)l4*xR!M|BXbrLT6#rocGQ{~1H{Cm01s>-3)6f2drAP`M_qdy8AbHm8B86f>3&8+OlHLQ6qpX_}S{02Zp)@A&J>%~BC87Fw5G82*hlSZmvT zLJjN#omSQj8@4&I^$=csjG-oBNH|;jzf@c4D zV>YsLWC5sc_k6&do`f;|v+V{8^f5`T)XsxEuI$pzxv5oX29ReJi6^HvdtYM5at4Mq zh%9SddcMb)Y&0DKNVD*F7c@|IlN^q8iYR*Eq26V+6>n zHi8Izx2)g$v+CzD0h1gQ1BXG6bA>-5nbN|4Iz6Ex$wE!-A+!>#IMm7GK~15uZ{!Ao}(wJYf!Y({uvyTW#5r zx|-S}wOrr6kbIf^={}(cj~?+?sq^K6^-+Rs)z)C@E+qGA>& zFay+0G-tgs!l!tu0gkOfd*I>9+~kcX3J{GLOfjS*Qf&eOAQFgKEsea@$`3ZXqhH`1 zKx$*d-U{de?spHuN_y@to@VV;J^qzqNGwo7umM6k)yyS(cDY!@hHN7sz=Q}Om~yd0 zs#op}njkNe2XZ%Wcp_=fyL|%Y_$PZeL19j6q*}*_?Oz`S%@ODgRHb6%pV9V+BWuuW z^gFgT7X@F;OQ~h@di_0ePpXY|bV6`3LAUNF%#ENBAV6T2#hrAQy*qXsLlirLF(%s7 zx>b(Vv$?b}x0Y+MGPv2wn&5D#EssdJfxD;Ldt)jBVTsm-w2w-s(5_CQxPi=-AP3jz zfok9`?!Y}bQ^YGbnJ9##=vs!lzzC;qPxX^mHnQC~yNW6+l_(~t!Ze--West_!WGDt zYTr%`+7&a1)Fd)0ek6O$uKOR=2jE9?a46_~Wz$ zAvA4C8tH!rUIzWh+yUw8<>g0kP;I9?kO}waZkMGW{@rBo{hE!!ROmU>WFr!TjftTx zo}~%?TcwiImF#PkPwx^e{zH2ONR{^jgRH}yjH=LE38Jes`>M5x*M1>`CDj%?2M76v z@tkkT{X#DEOu5cvvIsC}FMA@u-cHgW%Oy=$kn?&U9q9jfy`t+12q9@*FVQa|q(KBv z=venzMm^kN{c6WDdHf~8H3RH~W+ITqf@iuKM2${>LAA3$x0RCOLGetSDU%JcuLCjI&j8 z^JY1U3ERFlQ@Q)icZiELVuiSVO2Svj`!|GcF}KSSG>ZgV*T9dX+LVFEb@J;Su0<58 z=j8#t$hG1qZqkKTR`$}#L%MpjzuMu%GhI{cggq~J1Hr3LJdGX?a*n#J#H^#@v`pZ3 zp9m9+cxE{k$|hq8eg1L>be=8%!Ojp6DO=f-RQ$hVY`fE(`lxF z7)G0%R}$~Ja$^>c&EiTIPSJRIRA2ETuXZu}OkZCMZSSJ{)+QI6RhOu=%r9TQy!f@z zTa9BD=Pb{JbedVCG5s zcOT*o3ep~|**;X(Z!!EmJPfWUTC96M+`!%kCzc7F4#BNIL?-ks54IxMMsx4q-R<~b z@4;8Ekb2)GL`y9HyWs74U?B_x-X;}g?;n%s!ksdQrq+O>I^!WaLZp-iZ4>7$J zj4K?*enh1G1Qd*KBPL@-$d=8{OQr2cI$%({K)OSOS_g3^VOl{DzRkPF8HBcg5h%u&P^4Z$AJ(Gsd|Nh| zEcD7+uVT1iulK4|gt-UzI`#Bce@*l_Npao~8wBzQ4ql4;fEaf?PL-}GU#VL0i&$us z3wpG&U8C$Xzx$U0qK}S^wQv$Owi4*Lhwc3qn4P6Fjt|sC$$Cfh$K`S6eTpoJZf_uLcEoY3)!!qGu^*#IPn^B(R4a=q~)1ffxryjKPV=&B1+Pm!i@ zRHE3R-d$@v(qq?i7zYBB8ioK2e}7E?!DK1iJfZ4Ts%EX~ySFp{o!vAJo?)UE+Q*oirlYVCx`YgOW@UNlo?zboP4x;9TrG zDR%jQp6$RV)fRc=BbxS0PT<4I+OAzHvQrvI9Lgpf_D9f8vj*&D-i~V@Mw$*h?c4=T zXG_14(Pu-upL=*b`(Y@lHP(5sTet`b6qzpexbKH-b0f06H%JQ31fxAl`?0+dXQ;wp zrLz6NDwbopAUE#Bc6p(Kdm%+V#i_B-a~neSWQUm??7xHSdYrA1VzZz zEE=|y48ya%c#JDSNPQ-<9X(J6WztHY#j51!7v$1Ov z-_N%`AqFxy6lf?C$hAt}y`v{G8XQ&#^AYlu4c%kQquTzJ-J~lg3?!Gxx6ue9qYemJ5GNo=Rf`V{f26;`d7b_M?2%V2653{!)TNql*aqjp(nj zKj;>mo>A5p0Z=A`0#LWCfz8J)kd#XOH0)I|HieLKiLn$qXe3;ihmXu1L*Ubq24FJX z0f_Kbc5r_$sr~=V&+H`iICv!y0A$chx(Ve0{%tXLSB3~EF*2oQdgWk%5#_hk)Hh}b z;u%0Emc^Ak0Njj_g9>puaF@3zRWf&T|xK8g+t%O5LK>sI~SVetaK>q-=b1P#^nVF z;YHODJb8(iC@~B|W9t6P_lrbmYXLKK!d>VMl9=c(>gzt6tk60GwCfKzl{Obb(}gYr zX-jQH;k^=avK83rgn~po{K3FLaj+7cQlx&&cosnM)YX=-VT(h?ZXPbePWQAfATBHK zmZE0!Aaj;vm-`hSxtD#W$KIK+6^ifQA0JT7=+trV{YbL3V&JytX)kT`oh{Ehmk*k) zrU=atN%l6~OeSOu#RyQs;4T21A|*aLmSmRQXri}D$euTth;a*HfS9HbL4Q=#nf&e2 zV!*YK*Si3=PrUi;ua3m}ui5cVgT~M)h7OdH00==cSp*uT8!|`kkAR8y z1wd-UwTV}v6+`XczI~H)ny{F;K$L{-?Hdui)Tb^i-yqI}8gIcRJxu&2;Jy_FbK~oA z)bfM+O@ltm`0p9>rjOur-QL27nr|GrnP(TQHjd>Tc7sgDV7e1C$a>kUAUH*Ia4sKHA>AF>8+ zPq{=y@DK(SFsVp%LHF%kHQC{K3BIv;tvVWATP){f?FWMDW)p@ZYz$B|D-BMH#D5Mi zG@A|F@70l9MyfJ&tPCzGoQyPf8>i*!b+zNYC?NiD6AS6A$wOCVAli;f0$B19lkx zi?w-(=ey`yvCAv9S>W zk;Vs|k@Wzs0{R5+h*tKCn#3&EO4;8FYRh4+u7eHftTER<;;iYoKDxrHJ=YVcv=R0q zkzg5`n79E6i~vxRwXsOGxgFY_o==VM56zpz9?u<|nJ=<#96S z)@xu7U~J!HlAY$Mx>zC93?Ovx2Wjh0*V5h!uI_=A;d`o5l9C?0w5hc6IPP?A{hHLL z*d_{BMJ*f)&@haS(AaHY{%9(%$?ksqhWSnLVeR*TUnX7|JtQaR+DLWd%tR`2GRDU7 za!#Vmd~S8t+&0thA3r>RmJ(JXAzl%t6XFRh(nYH+E0w2|X}M=y61n+b#pYs8LFtq% zeecZX(Mw0{V>XB!4H}XGbD;i}0aw%!Z$j^h5fTd8rX6Xcwkb~P8!X?bYqOI{6#-AP zVg~+b!&{CHY9^{)dc7j@)pK>RD`Jm7@qx1vL6||asJ*m%>fpNNrgU(*j|Yf)b=de0pZ+QAjCND;r1xD0 zl&4ach^q@yV89kcN(7pPMwbB$t0_Zxnt=^CG@jKUgo?9OM_5kynp*wQ-IRxkYhxN-T(`|HbpME4>U`P%wC#Rm`6N&1Y-Znz-D6ldJd5Pe@Xtd6GpY26NR`9*i11t7+$d(2n?vv5rgpjsQv@ z2_YuvjF0wTwR+hUp9)v&3vrm)dwzo{zYyv-AYYWs{O_bspR_Pkhn@Ne2b=(V;IZd0 z1&{gy&W&b#Xjs^`wiV=|1=xccT6@{a|Lro9wQvCfnePyv-_e4_Qzl4hGHtpUP~T%B z^B>&?LYjlRfFx9)0v|NdN0o1AY8p&OXmeg8?%_vc_n2q z2u9WtNQ8?jDikRu7$dC4sUSFt>pXn!+%pJjToyK&;>JY7mm*G}dMuE6 z>vbjo%+*S*hwv1bMNki?_$AV+Uq@!L9`DP+T`)#hk92(K9}KsRL)lG_ z%Y`6PFXz;Y36KG55>JZ}ZgiUHV<3B^2wD*JR65@=sU?2}&W!n-B3O*8(L`mUh&q^{ zk~(duixmbj&p81I%WAat0Pzq(FGY`37TlKTQ&*vc2Hp>~G#JqLW2LW8o&aeBj|<1> zzKimv@cLm!L3Ltk17RwFU&nylB5f`NrVOFQNYzfpmXwru5)dBT34|$v^6qJz>x+N2 zh=R^VXeF1_HkE?wQ^_zJ3`>2Is=ARG&v1uWs1lCvhMw>%2>b`xUjd=~?)`Tw0?zJUmRm zyV`??QSN~7|3tC1<4Tb#Jr8LTaEfc|=rEv5Mzp-Lk9h118esJ&f(-^WDK zCIW8;x0u*)*V6dL{=V3pd>BZwg{~Qvs2j$mQ`|@!W7Kv^kIMt4ltYfNh$3T4RSTS* zdW&7Nva%W{eqr!g;P_DlR#(7zgjoT4Rj04?{Kt?>aLtd#O;I}ld5_<;7#rmp6%{r4Bg&WDIUS}nm@)LuyaR+r9vpdqU<8<}+%;RNbz`cV z4Nr?-RrZ>_{n3M4DPiMc`Tz&Qiju^JKr~u`uGn_4cez!_6mPgse6%kR*eORkX|T{@ z!M;2()t-iK3xZm`&L2qY{iM!g!H{FL2#z4rK4{B_l{%o&z+gc`kKCOwfXtuRU%U3~ zG4(NE#IX$5PA-gq+53#nLV1OG99I*s1U!*|H_QUP2{f+X%z(tw{9J>fNcTj=aJ2Os zX$!Zt#mv>AL7BK!8abwGN=gbIN%Q9?p3(Z zSJcTvVxv@M01zdhvkZlak_icreT|7i3ygfSB+2qAYHU0e624boZR4xsF!uzb>zt%B zXr;*>US9R)>jP{(9T|nbvNJPp0zS+$$PD=4Yj^h+gzdNy{U4#Wq#?dgwE_GJ1Knqg zroW@0fD(DRye=RkBZK0GZB>HE9HOq84ZRB}itntop?qTj&b^kzJ!GS#4FahJp+Vof zm_K!du-^3!#45Q6$Q^mz1Re;9Nx(+sAebn@$}a{Saamb`NYg=gIxW6$9|Hn3p|_Js z+V_AU-UY75KpBIJ$I%aI{5>EO+C)SN3Kj{{#JK zh)MPhVBr1AFIUPMK>DGip$4W^QEBOsJQ`>BB+8a>MZ3XnU(xyv+TMaz$AZDj=oCcxXwXc?>p zc*hJpp$PkdycEI;4^%qc(^g8|Z!R}VgO2c6oRB|@LB-kI*d%6nvaz!Fc4f40n~+H@;|TsG8UVA&23hfKOW40+>Wr0;|`O*q};li?4f zlR+l6MB+#$)(3ruLXL$H|AXr@gBgT@PBJOtYXN50FzgvUK|OjA=Cwxl+3z+qR~bMD zBD4t7DaGQ^=``^0dfEed#TrD)g$^^A;|t3LTtk2@9yr@9R(PaNL1Iq)3JwFjbIr@k z+l8qKWMQE%2q}~*psHbCv;I3yN3X=I>9wfU&3?yV!()3=<$(dL;ZW5sW`LxzDc^;9Ug4jtn4kJq^ zQHQu|pj0B1RYKSyK;LTM#^v2>%@&Zt;nU;n1#_#q#HN|vxqkdrxPv#+_yxjPrrjV1 z5%ZuSwbt(P-9*4>gusjiLo$Q0+`G99z?h^53BgzbggH@yG5-|D<- z^p9q|00?aPAMCw(SkHU+|NAl9hHW0AZLGF3L`8*S8!LrUA{2^LLK&LOwz&)q5@kry zV9JoGK^vt(h-6G8q)3#Z(0Q!b_xC>6?_6j2@At>)x-Yk#`h4D>*RY1?damdC@XrDT zVSB4CUACd-o^tC}OAd)f>C3h*T!FD11bh>&!0(?0zndRd>eNXh1S1m@={2&vf1`WO z1Hy2!v9I4GCI0BEislyML8H8Y_#NB6*6Rl8jE5zcqtvf zasBsLA1%#K?93*NB)-%!YrF=KP;}%Kt9fzCKPOM_F$(;XyM2wrA{5?UZMX0|JI9kV z4M4PmYnWKu{ux--u|?~h!hIAd3B2hzWPR9qdR zT>C@+e&hdAfaw2obHLmw_nRMBYktgTpp3sHTwDL7Ef7GD9^NgapAK`%`}Hb4yraGE z$J%%9rJFw|G!2nmksV=K+b`3u{r+o%lrINg+lSW5368o@-mkEc2xk?wJHcA4|JCZ} z-IX@FDD=nl(Y&bbH;^A3`oA_{{G12>v-^TS|F>^&uKg^vfBgSV9RG`S*0~mH&vr9V zT6zca@;V6llyU;WMgszD=_!VK;k*>q!~o=$esMfybd>JXYayZ zJ^ttG_K;60D$3~A3dI0~PU?4vaQ#KOj~v8!)WfmR0)NjFtUDMflP27`p@L9$);?(t^2 zF>2|5+MqV|o6Hhr=?wu>;<@jP^ie=K1;oqi&~t{%OkG3umOZ?3(q*HYn$P zA2e3*W+I}m5E)2?qI-QQ*C1&lw5G^!#Dk5gc+_%*>@d&@(H7Pyp^vvYUaSe(Imk#Q-u6 zZj>K=%^aoLB5_@Xr~(0Rt)Y97YKmnBv`&%g6>n18pABAn48bU{9a!O1}$6icS|_+?S5C*BUNbJtYd!#p0bgm*lVj*jOFd~ zzUVl~rpQP#C?TyR7fQu_1%c(fg?)@oV>@`Y%8ETUr)D zp+do~D0GI69B}&R%}EFou0DNQn=Rc~QBk3s`RdHJ$GgBB(nuiySuVhkp;8VEkx3qfOoRG9a4grwy+=M`Ufnm z^StSyC1UhmC_SF$lY=$wQYl;u<>vgwi*tRl!cRQzVqNDl5LbVqFVQlD@v8Bb9?wKkL$g$!Ybv71pm5%c1Bv)N@f% zY409AG(55KVxtFU`(V)M1o9;ioJGF`>N;ztANMx9#o4_DOF9%Fb&mIvB|5~Ttk@ce z_0V$$;XIApRvqJT(FF8+jhOE z1$cUU|3il%^c3OEM0)Z0$MRwW++P?u*(_f^6r*gl;Ys}2v!l_Z0DD+3O}rRG@3X2T zR!K=o2%1Lf+(=LVP5SZU;s*Rme`0f89>pIsxm4RULc>?DMCKAFWqd0gzlFHmc+Qa1 z_e2LZLle(W#xepml%zfrSmmSgLCFvkA4;1A{wwiKZ7*-pwMGGp5SBM!wAxxY0X>Hf zRp#g+aJz;fF5aI?#oNaWt+_>S2hTW$ULrW6o1plipsXZv<8-<-hM#kG>Hn>ssp{De zbbGB?x2_(5^9JA(%0dT$%pi{s1w}eIIf>+uHgXkFaAR3T#hL4W4eev}9DE!h1g+gY7NfrJV&nNB+WNt!_QWUSA(Z{PZJW>E4S z=Q`CZdfcdM(=dU;2_t|K1{R92cdN7iIa*ce6S-jlNC6t9rF)05g59Sf~k^NDmJuosO7u4-%Ve$0@GpqaM;K zosVHT$=eS-yYcl%cW8HffTAh=&W+S|yvJ-|R49g7zs=EorYFP&%U@v{0mJAJqJeu}koGOST*0iqckR?k zL$z}qBae~X->#tUkcfJ+r?)^bfxKSCm_fxE^|EB!MGQFYN=>QO02nC$qOK8dA_Y}^ zCQB?-pT(&74?gz_We0Pwu*z(s9##9pb#xZbWxF+Q4!mjV<#a^1d*>+Z{6RrdURO#eTIdbWh4{})`5j0zY*&p4|h(82g9KMD!>n$GtP(xk_@8&nzDQsm}Pu#P9t=|`uc#e(4_GjlTj3Z{&{46 zFf~dSyo7}}4V=c(FM1?iNQxYPf7`rc(>h=z2j{A4jNkJ7R#wOg^vEt;v_Hzeh(*D< z^?b7PE6h_w91DqQ7+AIDJG+VA+1ll?OS;{eyW{Oi4c{48W9Wm?4koh<-+EAU%?o{!^iAf z+={<=Tf;Yzysh{w^e55sO-&;jhVa4%jbhl@RlPW!1EHT2c6(|8Y<91yftL1 zVPe-0)#qbQ&Pi(ONeL4n#Vag=j0_YW0w*5n7TfOeg2HUt(y3PSxk);QMmjuk$|iTn zBiBl2&53}3fZR7XXwnlFDb5SK(%p>0-0xC@{#bRo-y`5QJiQ!qnt0dH<7NC&o)cT+ zJYyHZFAV<7n? zwBuJ7`~V2k4*a!F6t#T$cVT~sHs>Ih&XfCJdYkStBm7$h%z4r8qAe5c?yyuX&~;H? z6SpMY{lv6=3MO0pb^T^xO!xsLSaND)N~nYed|nyVh)x)ZOvi!Fmj79h-=j@%PmL)U ztR?X(Y5@(lq8u3cRpwbp)6JyLr}|x@gQYg! z;L#g?dAWC?<#-io&@V~93;y>^x;;w34W&Rh_-rA6JAMo%DKER%u_Uk8?&Sht&FYw8qGb8-5P116G#EnR_gl47 zg;>S^K^(_M->O3qjT+DD7RWHt0Auq&5R!J;q{{uh?nnO_~Yw^(xW%W-dSp7 zxqgI1UM44ymlRHZ*qheo8@_ef^A#N0rYz!a8ci0rj!fu^iVJ@Ccl}+%;Pgtfb@5rm zz^gvJjlFa6gGF$!bAPwYxnB3-hDm=QW-3u3o0CH9CAlmO7AL(coh2+Mpx+P_WVNjX zuf8+2nDXym0$3oR@m)7!H(IC~H6@oc!@|O>Y=bH0&Y3z@^Xir@TXHV;D;f}P0Y;lG z@)3R+CBjBOYerRi47_))&<1wz872k~QE!Kv>G_r4z2`r<*pBdb!=+0;$Ac$GD!Cp3 zC9{%i{$AH=CnPkv1^z9#MDZwp+E5wUd2|5zHX-EBjFO|q$w&W;AgoxjVBE^S)hq|Y zMS^UrtEys}HQgHqr?>@qb&&0=m-qmUZwEu_h)e|pIPLNK4gDas%7nrD$*YDT0M;|8 z)rf+<3zO2_rV=O#cpvd>#<8PEPeMrya-5vB_2Q^1(YoaMrSJNfM?(}IS`m}Vpgl@1 zQo9uGDje!HNOvc(iggzZOzb_itRmA}xe^l{lt=Ni?4q$3sm%uFwLAkfX5LII=|Pu( znVlG(T9!%VCQve8mX_~jf6s#aLM^;}+s0@{Hu z=fK3H=UfBx2N;{6ypa8FFR4EJcbpFaNs7p^`Sjzh-M_NxH%=J4FlpRihZ*tyNglc{ z&-OESvu?X#Zo}2&60v435YRD)>wonoD;5=|HW!gSP_sb(DWV#(z*FrZ>N5*_(G`R*NcscwIL-To4YRW5%pEF%Z3P z+i|tAL3T7)AdFW>fezGY%2*1A>BjGhi>0SlGzN(`gWgCkN&)K51v8{@cN?x4%v{7K z3ZbZp_zC5D0dydO?#z%$W$4R0XOfYq{`gV7FlZ11Zj6sq%6JWG;DNTfPC51Ax81q$ zeblw7=1P0?hY$eO>Ok{=PBxa7fvJBNN+3rUnA5zGREw#6Ogr-TEZaA&i+dBFv%wpV z9qWtx*i}-Map@6b$Uj_i3S7s`5dbZt6SQJk{#bbbTmTjINoW%?7u5k zgj8l%kS#==ae=R+G4ebr8KWO=!+D;4`y|olTtyYZ++f$6WLBidzl$kh6_9UTi=V;x zSiXz7S@lyx^IC^mA+rYEXw6l9(Ac5bXm1oZ{EC9Cuhtum@ND~j(>GyTM}jT(>-RgC z$x5BEK3?B2mXta(tu9Ff43B9R;jUseiyR#qSXTkbh2d@HMKLQQiH_?rwlk;Xz=3YW zey`8GlSx`d1Ireuo-(Yb`=^)waVpz;?k!7XG#Kz-psk~s5loIQ&~R z9=0JC6IqA43R7{xm)ueT*7v2G89 zOO6@8g=h~1{VVW%>7t^Ao9@u2iwIQ_rL#F^3pCeW)3#4y{FyU%y{23c=IhkcW4zxo zln;!o#4OKI%eEI>5Gf+a^HBBUER-e)en){!6*d1`>?G0GtFLbA)P2K#;p$ zF5cLwFL2NrOulodA#kx`I2GlO_iZ$s$+t9|)l7GzFqXFDR{RkLVGLa9=ah_Kp8L!-$SOXhR_+ zsi$C0suHR4I)sD*4?v-hTXNg?l%4nklm?tL2a|GsKG!tbKBL{dd* z`8<6!f_9Tbh`cFPhs-<6Rv48@ut1ZmPnp(+pF>x=_9{9VOO7dG%mIF*$}WWK=`4Qf z@Pg!B8XNtJlD<4|D(dDMcZ6p<`5pD$P9>MlpMNoJhQ>bb$_A*iF{X^I5K4N5P2Q!7 zAgsM*68RZGeEe>XkweV!wX2=bR8gn(h?!*Giq!%skp1_4 z(Y}@WtfozxjKC2|?pFv_=J}~q2>`MH?}kf5K*i|o?r7Q>{UMZv(%tl*AU)vbS4Mb}+E81f5~HTo%H z*{apfPQTwU6)(%z#pl{TOQS}9rfQRXHuS804}L~1;eSzflhn<;RUoJ-F8m;I2SO=2Mp~ zT6FTBH={d8K`NcdU>DSr4xt+)h?&q$b`2x=KYd-$RMtpd3mc&onk(gr&rH| z-J|j7-0W*jobp-UkCo?PrY|b>s!q!UoLcR%_0NO(X zm_lZf()OtO;n??|i^`vY$(}L{%DmNF%H{-{2RS09a;lSv%`4csv{wra!_o-&}Rf2kgun5?QAk z@+p{28K^3|>uK8jre|UJ{lgts#(<+_-r`#ZoHbh{eGwk_pq&KKt=*Ho)~P@*4eJ zwI5G^@scH3<+F9<5{G-N@u*Ed{-Tok1!HUF@28N07`+P7rrL0DV{52Jh7$x4SItj4 z=ZY|1>$;Bh_ijI3APrkx#f8(v@&_kWwDM@1e=4swe7Ubi@L)$EpZ7W27(rdo1ioPTlc%?IVtQlhIb&V@*{9BHAq6s#JGHYCC+?Z!%JVvgGl5 zOA!`n$|`S^aY|-p%!*;IOA9g9oWe{iA2BDduharLw|!LT8tpp(zJKfpXTNuWSG=xW zy;?vIPd$YU5Ow{5V|x!LTU4f~5}Ksw21ada+PwLVW!G3z{qHmNf_aXOfVJg zT=bAOQt^n0h#+2U&*4vj6}}qcYPTLeo`;nN6Ndrb70?8I8OW46_}@V8;L3Gaf>eN5 z%M8NcLxhXcMIk1O@@z|x2>7O#0mnrUL8jUvpi>qB4~(qoOs&td{2Q9~L4NP`!?V2^ zsnG-zK0#$$S(Xl5Su6F0vu3RkoRr9jp7=IW?s)gRl+~poRa}XqxANbNKK5`f znP{-ehb4jR^e7h^d{9$En0VXC)P(H22U(ihiVY$KRb`64AR!yDUh>tPi)j7TxQ>rA&_=jCjBffZ5dRHUu`+bff4gZPw)s*d9pHv z6xn6YEuz62N6@m12W$yR3}ML2P$u9rjU~`3$2|4qTuA1`@E9nxjiFzPf-o#L=2tg_ z<-C#0@f>jwA`c`+jWBgEH z#xc8O+o~F{DviubONv`d&I|PWjQ*8Pn>VKbD2W(IbQW2^8^2au`Tp4zH&$7a8L7db zEZqCXB@Bj*g-)_`QJzPlOga}S;YMX;K2E}uMo;(6IXJQrM9#+2P6}&4`i6^rhSG#O zXTNwJ^KvQUb0Ml)eg%c)J#1d-x<0gYwE-S__*;KrLcY(nwwr}5kp98>fUWbLQb?;3 zOx(LkHEY1t@}>s(4Ik55>ooor=$*n`@V+^NWRxb=<+Q=*C|$-fm(wUF z7e}Ob_7z_r5(&o*Lt&}^2`Q_t7-qE<9#4{1bfyG2;(~k8M3TU)5bZ{?2wd*>>B58VJaEwI59Tefaa)jwZ=jIXMR86G%^?0V=js-rcmKCUWc**SGtWEH`~=1QsSY{`92c}!~-c&t*m~lJ9Uc6ir_|b zP3rYOf`N1jzwS8f04Mn&1vvkA?kxveWT#$ zH)zlxJxA+n*J~63)T%?RH?OeO-11~=`BDS=#V}ptZZB@2P+Z=-Ir*DH(SIG0WvWxt zX3cckH>M^Mi-$NeuQWJVYk7ISW4a7}_NWjL#5?I~I6|W;#TZh5(_r^L6A!E|FToyg z`22fa(2o2=S=lsu6{3xy1G*Y6pXVNWcTO0MuCo)$#pYVFM*-DKT3rtF6^e)uR@8;yOszV~ zIEO^kcU-GflN1_93YX|IS5Zn2<6#l_KBXr)K?TRU)qwc-=M)Oou1 zq@(Kli_-tXX$#nh$WPqU8#nrnUuEHFLYXKlLJmdh2^itdT0K>l?Z4U~@$s&^|I8}A zQQyA;ql&1+?Jq*!GM>-rEOgyI!D6ln88p>FTYra8_2vc1&!B!xBkK$$JioMA28L|!8SYq` z+m;aC@Zp8UM-0Cd)O-8j^q-q=-Wv*sZmxT2Vc{W6ZQAA`2*(;C=SsSv;ipwkp%^|U zrP)}8LhAtLm0n_Vppv~p!RKxO-dRh#F&VtsZ03@5ef5p_>@`FYf78!TcSf*mu@DAQ z3yC39Lk5QI=DxgviNU{W`m4a(Bacb+ktCB^{YU;FUE~&msYAg8VJX~Qr zQgLFtUaNk5vq2+PMV*S36hizzHpsUu9yo~=u;D)ktUSL0_Rfr?)FydaVIp(Z=77-N;|Snf(bfvGUkk1^qMf07SmJ2l_ib9=q@4voR#4Bezpvl8vA2$nrBnn|mV#I=O;%4Af3}oqp&~@35=M2Bw98t4 z{?7?rm?{1mi2fqekX52-J}D0A$I;EH(?E8DDEyt$uua|}DKn{Yu6yzJ?H@9Y7i-8Vcn|lmbTvTzw zel)>WeB#M7ztp+tVAD2eU(L|Hi{qEf){cZWxLdvXa1N6N7(=>$bR5>=NDIR+!~Jz& zOdPXKeE3CoEbtJjrkcfHeXo+*c&2&YW+|3ZAM>n++Y$}g*4X=p$73+eT%u5W^AE&K z54mv(uMXeK-*Vy#Gi7MPxx-7V8a!>PcW`g!{Y$MBih|^g6G{|{>#I3W!Ud4~>%MH+ zE=pNA)}w@FF#rRd=&xnYYZOk)w&A+&iFeH6#h`GUtQ1&idy3`xd#GH?+jqE?KG~EKjsu#g(ei~b1I=f!ox@{RdKQHcsm3E}6 zlA{SD-8O~fTwY~~DO`e4O{{)*DPynh2`05jdq=SdmO#q)T{W&97J|CgWI|3kdQ+!P zjVAfq(|^`l6kNh)-qyrL>_Os;0gWfPNpiDJojPSfZP@a_@H%JFMCL@^1u-+r&I}uJ zj?o&&h1melM?*|}0-Cq6fUcbwhPzE!Kb$Ns3628=rKb;7~|`^&g?Z8Q?28KL^8 z>8LzJ)xeqAp^RT({MeJ6Au~A55d%h=E-zVAV0VQ4-M_7)eh0pppl=&$(2+rTul-}N z0M-*HP7E)+mL0fn@KWc{y&b>|PvBOI*{p4!9~04)(Au~4<3V>5N4=^qUv~HDSsn9= z-e$SI^FR&DZKbIxX6os~RQl)er8S8Y_26;Bj`|53uYdWX3%^paL;22A4qGYJ7%~2-{sdnolw_t`B9XGNXwRr~!*DSBp(maomASh4Ga9lfGmYL; zO1t2BqL9hB6l}PAX7Fe}ynla-BiAOYgvj8{{kTJ5lyjiNAO-!JGQI_%a6&n0PA6t1 zH(~5I?V1g3G7jB|ibF%GI;US83lj^f+PmGKec*@f8;vZ-QBus#XgQwO+1{1r=~l?wTF zG)%*L{64LfD)!YZb_^ZrE|jyYsKEMvPn73eU2}?0>Vhd6yr^I`Gutc4+ivVLc;wmB z$pRQDZmOZ-e`D0*%G;H{_ zN4AtVXw{uZIT8PCh3lE5^Mf2|2@)h^Hut8?4~(E>l>smT#*dPBBqDU(i@)y8K&p&i z?6}p>KTvgCds4yi5@!%U`!+}e{P|i{l>sfBPpfs|6X-ex4^zHpQG0n-J9L7Pj|r04 zhe-wUz%@~D^hX7p*0#?euG2*%RAQKLpdns&BW(LrOI zsMRKQr_T6RFQGPb7LZEwO4Ml^#mHj zx*41cm3-QP35<7ARujglqvNXx12=h=)hOZ=y?GNuEaQUH zKC!f?BY9jA(QU;27stJc@Dz%L8mHSTb5Md_M9z{1VSHJ7Zf>q}v_wD*_t*-=Lb!54 z?>h}91N31)UgCY{_EAn9%wwor+v_-g+Wf58=h|~JB6^J8@|s`-GaJHCyEcB?T|a;R zoW~Vv(^pTeVs0h&7_Sou031|sLGk|Z=3aZW6DKK59iD$K?Hql^?D5_kt1J9>{;=J? z;>3rAw`3EN=ek?}JWjeHQ}Yd&{5$yeT6^L+bPZt%_iNCoQ7qW4+!mZoK#Z42$N-?` zfnW`UvJJYHeqTV>@05dL?2?oCIPla5!blxY;E!8O3^edr#=*|$xC-f5WB+yQMuwI^ z9@~30LBsG>M@tev=ii|#=n#o@gw5{+;;eKqLdm&|m|j|z|0E)r)325Wi28w1 zaWZj_l+I6CKPRH#%i;T!WE0-TdE9tO6x8=l9mVM_gWK%G9jJXVsGeK(@!5BLNB<0? z6FqBoJ@AWmg=21We$m15`()|p>)wU!>gpcg@}G5_dNDb+81~Ui#8kNfM{H6C00}|< zz0Hmby!Rd9Wn;s0giR7Ca&uuB%?h2`amdH_0!SIwCsOOb*-PJ$ITEe3IyRFEs(+zf z62P!R(arTmcGxtIumz}5(evlwG!J6@#tHsFD}=_nlmj%{-s`xZm1_M_%{A;+dCb5Ck?S~vvoU!AQs#z!l+f@u=%sgf16Qz$af6N_b%xxD96IY%3a(e3A9lm zuz@%sc3WnUA-B{s@s4=P)k}81(^igrm6<+j?AYW_dU4lMxzAuRT_}>V$U^2`cv+VNEpLJ6cI%*p8oqJeE=hn%P8wj-bT)lMhq6Sfy zL_-b^5j=^RyZn7;kI^qmOGA&os(sx4`@v45DL{Ovtjxv7RB(33nf!>+-Ru<*pMBob zPif9|J*}LZEcoy-Gqc|+Yh7)Y#%YcjzNYSpfn!RwVkoNQ4k=4G!~*^~Wqtw!wy2G^ zJMx|A8IWv0+bEIRp5=ui^ni^rqi3A8)rzM#$yiz(=toXZ+JQi6Aek+_D3U4IR|EP` zn(Jk3K``X^w&A+JpdIDnOb4Q}7?1#o;|Q^HEY+p-Z*YV7%If=RZmy_yPW$iUcltrR zE1>Sfiig+ENQ(eFyAxs;4ETd2=4qR8M}8bVbk!F~{K~ksJ(8BOD8;xlF-mBrMLp=Ykm#YeJ}gl*8qWpON>6Jo3TjX;1k~!{0nYTt_7xNb~-#|Gw{q&_3I| z-42Tm>mtVK7~BX7`c*FiCfJtjp>!Xh@+5hc9b=<6>7oHN8YtM3IeroeyoD!%*S$qP zEinu*rwa)l4@b&k7W(nh(h@b!yBUq~lw5YD%$$e_rKC|LXT8RMo3=19l6Ywzv+Hb< zS0)nhT2WQH5vo6laKP<3L<}5<3SEoGn8d-EcXxNsA@9{euF!L_#KDS)eUJn}lmwP= z*z1Sgh$8oLdw#@h?_ru)$ap1TO8pdUo#=G$Xh_80|FO#5$z0UVHadIq?siiFJz+P)k)*rAVCcJ*P8CPsCx1n!jRx2^WjcP4#{ zi=I|=krAPs$Q>%{H*9zVmLT}&mJXWzjAL+rabI46RDe)9!wCLRUOr$Oo$%o#0)lvG z00#20&r+olW+pZu7F;OvS&#Xq8*?wg(ff#^+!FXiPtUl2a3o8I6Y2?3Qi(E=E_D`u$<8%w}U9hANe3TPr+g zWGxys=1Q|P-8S0FJ9qZK(kgvts~W%E&o=q~T^sx#mPhK5J9c?GhpKItE(_je%8UH` zGHNupx88B~lsfz*8bkuo2xF^Q5383KmH+)I4BE_HYYQow=F%4=0!uh~E7(YtP*q=t zdF0;)#J%BW{GXp!6kjHWDNCvz4`mb^WQ86c72OJMn+AC+f|tki`~CKxAMf_Z?#P1k zL^EFA-Z4;$a0&eH4K4;ciuT>PpZ@cc3ft{dn{0#lX~_c@p0TDPN(JrZd6obCMD3&6 zQ!qkyrtl1-u1J(Kl!N~(f%!e_XGr$*SKdFJ>2!RJW7h72;0AloT zr>^zj|Gcsy>+<{tX>f_^e|8b2oD2Oxe1Cm)YaS9=(D9EyT(`~opAg%jf9bQv|Ng{| zF%rR2DG+>I!y;l+El&9Sx&L!F0pG-MQ2v!;3X4PyZPy3)??=tM9ggzcB`7yaWf!1$ z-D8s3UA7=debN}O4=>E&9nNuaiKTiH3rAIARi+KKzR?|qlQ>ZDq!+JWMyAvjp)_KFw~|1cWJl$JzQM|GQ@}*;P<Z@ zPlf1#QBK;lI3IIbA=d!l2QlJI*`B_(33~ZAH+z8kka84*{HVQU8bKZlyA=4P=*C`a zEVt>`fKeOf#IG%sCnB{YBfS34^RJ^a*b0)iYn!pPqUAB?xaI1+W#}h!J5nEcfB;G( z7-jlP-k;y~1qm&z@hbM3^NCGW}@rcy?OoW4Oj^e!BY0IRHVes&7Ymmkny zQTvSK~jUd)p;--h=U<#P|1+f;1A?Vy zw0&m3x*NoetFDJ-wuTz9XQ-EZs3Ea4iDakk zj2=A(+D>)dE9>Nl4KHaC;s{SCPti>(Fk~1L#gAN43_mw|a8lP6a<-#kU0l0wo#?$( zDnbC~P$nny;r%?83J;nKY6UnJ5O}lXN8L>Tfj#Vh@K!5}G~>aL(;jtc)5hLCcrP}= z_?c8vB*O`AwYC0<#qTm>Z?8Z*w-ZBVB3=cUNDv3wFwSCaby6Lpq^oKh6Sfhwm2mY= zozdz0{Y6&DBm$s`!99zu)A5L?vN{60fEy z#bQp=y$y3d_9*gry5a~?Nf?Dm%<5#YNZ)a%??su9)pxT0Ym{mt@V zOg^6fz7Cpl@+rU%GuBVc)?7UWPMBcuD#pwYs6yLJ2XdIojG8$T^kbI&(s0--k2T-+ z)1zqZ@tvaL8MfGdYe1ayy`xUw9^Srf4z4($y)yLlfXYz%M2{l~4|roXc=1c$E<-)L z5txe=XPYL05QxS-f9*ZBu&&XR*|A11GlOrwxSx{J-Pj&s!_<(!bD-Y`6k?_JDxY(@ zGX?-BhlLYs34h)QpxSha=7TSxU{2gd3w*k9HDy^T5W~MPr37TnnpvKlnd0E*ch`~p zb>6&rr#?B*zn4Hn>$uHwwdKzHT`e~!PE4Ls5fA|^tmduUy0!m{we}7D%t{$I%-fbi$p4%Kkvxir%Txr*(%W*uOGo|7fx>92F5q7&vqW@9GLFc%wODJVIDIYfDijW*kbB0^M5E-k0;RI zsvY-dcl`T6^+A&niz1_+NPPlKB?{3`H7DhW=3pf}1=-_1y|=%8Z3`Jsf9$@ z48`m_+WzM;DX397YqZ0EF$mx>nOTJCu^<|f;pLz!yRUX^)==R8@&)^Yc3?_E+ID_Q zW-IQ93w;mBk=Ypf0ahuB-PZI~x?ON6{(7U)Z>}5V6yT70jK74W6(d%a%e)A~bXI=H zed}FIog)57>Iv94RrRtC!AIXcI(LEsUb`c!c+A=iKXsiKQPMoz!u9Ld_1-csz}>@R z`<)+*s;AHp(zD&=qjY>8g+$(IJ3W zie-;=u=_k3BxU>QB|ndb!ghlDmip=j9xz4R-Q3K{nEZ6(lB1;Lz`#wIbpD~Aw~k&b zWdLfU{fptOUL=pa+IrLQP7LkdLA>cV>?!v19v_5~#u^8|kEIW^8RTvXk2{BQr+_{7 zPss?(W>cg-v!vJ}_xOxykP0lJST7X-C}$=6KYhX$8D z)_N5-JuX?nnB0y>cIabhW-#xQa8-uk`5!r%hSJCgf~(X*dZQ$dSMSoI_TDINZmrLi zF)`oSE&A1%v13bcQIjbz`^IrWb|MDp7A<*a48gw*q6_|uSH&5fWO$Tc*I~Al=5sMg zl~6t8hT)6x%#oSV7Q3X1~aBNyX`=l%eOK(R+0EG@;VkV~p!#=B& zVDb{g02lY{%xV6hT0YK369SC=gM+P4?RvK9CXvu#Ty5v={BhHF&QGh8 zx5HoDekERjx{=G!Pe3a~x@W;PBn;8O$f*>wM(9N(r`tMjE(RuXb2&)JLqD21iy8z3 z(H99k0ogGJ<8!7cEaqa5KU7rc{CG?)+L-{H$!IhDv(KG8IT7ESiy}?oz3H0UzK3%h zdEzf$&!~`|?NR=7>DI(^`QGEu_YiQv5`1z_y8pHN9t6eg;dtV5@m2mhU9vL!G#Cpi z&z70u#!M=PjVI2yxTH87RP1o#bH7$poDq(galw6b&d11cP^a8;$qO{+9;&5j=B;1wG|AOH6zi&D zcWL6m5qqFVNJ#{qr>K?N#LZssmbwgHQ@Z0_qhUlxLaC+1A{!}PFg!;u;!7ci;7Qwd z>=-31)nn)KbNGRQ6veHgCZ;GyE25o&!JTQKC~uI1KqQT9mXvY|{h$qOOE2eV&9;3; zId0x^lf*Y2h9Bh&;Ap0$rft=37x~I#Gp$#>^DC=WoyhdXRhL6+&!!+-<}&x^Rek@o zV({p-fuYsUC;!!O-Ls%P+I4;AnDzm|LHJ;d@N@vpgbYB>S&GBccr1PQ^ujux*)#ih z>{@=z_A>;LXqJZlMt>E?KDuC>VjvQQ7t)ApE0GOpi|Io;lOynyd(I>dJrsw>kdkY; zCx&SYzQ=dre(%qr$?@3is77hw2M0z9H( z3>6>%EgR0_oT2-Gpp2ITtlEZbtyumvjqEZO@8-T}5jfSvHFX%Y`a!(L6CFMAObtCSL79AUxvLG=E7Tuo!n}Q!|@fC1p#0>#C-W z8`B@J{R{QWM5UX>9?TR2XzyL^A7_p`-($t@V?uxu+K{aJ2W%Mc<2O#9)YO_)o-C*L z=#!=nqqr$@l>OF^EdFZwPIwJeS67675f4OYI~m>J3PP=tRW;T7BVK8D8i#T!AH{{= z5yDA%DDaw^ON+F?A9Ydl@=O2dxl`8j^E=m};L!2q+p~ftR^*CncgaIOf>jys{^O^A zTKFr%h0r?t!`JcT`vWPVu_(^e^VE1RG6U?WYVlWP5IUqbHVYMl)GQs|7Xg`ErMUYY zhp4@`ql;BlI7<0R*n-^^iC>gH#3JjbJN^-N%s=Q6omMbyS})xHUo@D7rT60mU~n5$ zDAD3v^Klm3ax=AOD!D)=BTsKa_Xc}{_iuS>yF?IY^1P)$jJ#Mu=GwX zR)3X5j8lpWxwM2qFap4k0ZU;POi5f~6xLlxhY7Ey+1SXdzqxgA9YIhJIPJCeG{=T0C%@BNorhZIf>Y6Aeq>0IW%BPo zk42WqDNAyJPtXAw0HK%w$1Rt@Z_dJnCm^oC(U@C>`~*vVTxEMTs7)`7h-ArFyW^x& za;pR%AQoeKs6q2CL#I&J8&G(o6VZOnee>&JgjJ`oUsjZ{59`$;Ui-zr`UuU*C4 zOs2H9;;oxXG)6CBe8OMN25kObb~SknC4CjuRaH7UH`$zAO8UkYd_ciNaK)-&r|W9% z>bzcn>eZWvzFJ}T$KTD;`m6qkJiNN1rS|gW-bZF$_1WaB*^`MWF8NZ$^!?QDjK^UDGF6C+ zFo8>O2M&(NDlDGuH$u;&zxH^$drpfqvO#CUA#q%E z_(ll~D1HQdR3ztxvc#@wrA7V+x6Pl2Z@fomRO^LL>FkP>L(h1oy>qieqsiW%M$}+! z;RsoVenj&{W;ec}a4eNQ;XbbTUYVDk;X<4Xps_sk6N+|X@K}mFRH~#s%LyJu>JMe4bf(sOLWIHTtVLPW9EG|p zmy8RE`ojgGyih!#PkWpIpoHWivL;+%hAkSmM{IXp%FN}+@ke-=L}%Q*dGp0rzLcr~ zh#Ye;#V3|zwvjWUD0tfAP`0Ue`QiiC2Gsf%y7$u#SX~jozEP!{0y zAIzfcW&V%KAP~N(`h=W7+=LJnAYBTJ>*TX7tc z4wG3r%qm|MsSMv$#Cz1@1gz7~hMh&09R%VT@QE^tjarQ};N{$~sbS&&T;8ji4*`?MEzbxyPH0_}EeB>hCGA7?wN!Nj;6F-3GEfaPyS zoNTGA9D20#^)^+?Ff7N2*XV-uyK9n zWbf26rtstO9}twrtLy9OMS<{3KuST+zjCPE*^P(b%*WjXN5>@Gk%_=$qg1b?&%pat zm`Cd;TbBX>3BgWSbUf*U2M?mz7~y&?!MivV{Vb%tSGyji9nks|$|Ye#y>dKBlU#VD zj@v^*>O1t|mh0-T^tzclh1WF)G{4IbJAwy-zXxCT*w2|0Db!o?R#WreNw_Ax3H@@*X9-#lQxdN*ivDt(*TnP z%S!+PvE|hgNW%x)6Z#Pph1}{5{|NiplG7x3ZYh15)>yQ{qtV|Xy3qq=2mBDsSPuj` z4q%ibGd@2bO;3{0yQtXB%!y|J5VoKfE+BJHr*!ctBesE_@L%2Ssf)upM*H78E%IAgZ3>vTS<6P zLeICETNgsGEgq#@+w06V!ZvQbL#J3jxqS~4yiSX;gE|+3381VK3Wcx&3iKOTLE?^( zqd?`VajoeE1)|yK!DBfia!8^B>9@x2&V#*2DOLrDVUjF}em z`&dTp;QkyTI&{Xb_($OMV@U59Y_uDk(4w;1{d?6?f!3hn`(v885#W|a#y6WDh2L@u zrnPT5VYBYn^+W(&tblt(b&m7E{fne%#T7S^%rGEeQq0Sr2Yg2>RmZtGw8^!DWz3IA zkxhg&Y@mfpYD(Xq zFf_%NQW`d`!d1Hz=JB+5=>?Ej_&XhX^41g*B`Isy*XC}=bD*Y$c`&8>E0rOOg*iJU zZjhr&+ubxGLvs)c-vRug6*Bd8QBfp& zYSpTMqz)g4ko-~_w;xIy)H!+qMvF2J&GeEHxb2z;{+0f@?XOrzLsdIKilt-~owxeo z>*RC#lqCjh=)-Ax;^1}ne)zMkI0o@+$PcB47z4H(#81Nrk0hki%Fcr8trpGM5k$z4 zN{**>lX^d5<=TR)U-d#^S~LF(^nzH3$En8Mz`^7r3zM;6Al#rB$fE|Z4yE=^j2B2b z(Y^*ZCjeMvuNo>^;(BOsyG`M6;EJjx0sp)L+dKGyPMdsd8fcE#~;pQWx`odHO9T0vtd7LD~|d}p5HsvOlV)0n6* zmc9dNC?F3IlPBvzG*?(UYu|1n?0MiKSkt?t-EEbW)RFxYdfxsOpF&PSEP18OH>ma* z=9klP)vja5-KYMnuQ2kc*%e|;?yjc$=oLr^M|PVo-x*#N!i*^%Is%N;_~yd8_bH5p zUrX&Ei3@H+P)uI+I0;baY7UQ5nO{TWAa0G6$p}NYY}+z;V+faGV~uNcwTqWFX&^Zk3Dv@({E>|B@HZG=nibKi=a?uA zhotv@+5ysL4!IuMs~zL=Dz(kY<%PoSIUV|@D?+=rDswq9hi8H?2WDe?Uv-ey&)V2*I8z>!kLC1$Iy5%gUF2g zoK@p;>6290UW1`Vwt9znM5;e+{bV9eF)mULFMwVpj<%R)@edz9seJA|Hfldr{u}xlSkbQ5ev3lc`5`0rgBdizyWA9Viq(^O=BLLL3 z;<2f-B4SLRD-_FXj7R-t`AKOTMGK){h#!93>kEuhNd(C$domnL+HTcX-f0ve5!O|w zYC;dQZBRRfGFKQ!+wYI663upfrrf$e8LgRs4wpM9HESrz0yo6J;@kQkEMXNEOJbIB z{t&04g^5VL&+XmH-(lN zQQvSmCqpH&qE&Cs&3!Y7rv_833r*KZKX&6A4CyLV2B7vE3Ym-iS<*;4@E0b;jqqN1 zdK3Ucn?GCtKIO;eP%}mDYIlVkd~TmL4RO5mA?~XSSQA&hW;z+RLs9dO+SSd{uqOhg z2=Ev8mW9Zx`SE#}kkILauFauPD!cMHT4DSF-cPS?-7b43%uJ~pV9*h~t{;XfKuN}b z0+*O|2G|?ZOXKIO4ElhjYw}I z&F9L>zuEGn@NtgTk7}>3!&de|pHK*@gjYtNlXL4RoO%infdc6r(`}PZ>GbB*e!Kud zGc+Ev=1kP+CQ3Y;JbZzeV5K;LAkf-($vE!z&;Igzun5A!zj4VAgE2`iFR8L|+yJ@L zA3u&;^58WF9?Fk8S3%fZr^E7y>l~oe@uRg1u+hJTr3pEgugy0VHL-A0LeKaeIyyBTm~=$YX8>z4Odv`Jo4r!%f4~? zPvDZ&nXkwt)L;Bk`n+=)6Ioe&=g(~4(qQQ;cu24BmiB#PxSu-XSK-veGSns@yaGxc zQG`{m37o9R|NO&rvu>6)5S`HsvP91Ws@ZhMMy;y(A!sm@;^+x2so(qka2pruvlNrf zcA5H#4O{U(rIEjs#+(I}BKnDV<>4fn8}yEECSWMvRvyk}c_X?NpDO8U^=!0>QUWpe z|8tjV@`IU_naWeumgwrS$~$)cVYqw>VSp78mJS`ndf%x|OeR8LC#Le)#bVMh_P?N~ ziI`8d*gC2apoiu0UBMQ~O0_}^{PROU7MYZ_RUD!dVol9Hvd~ngN2JgRarAskK32}q zJe!=F8>U;f!{HFMbUx~%7d2Ix+Tj$-T9)1XhseWpTXjWoz|s97oLeHW37}QNpCjH` z;0{$|o#eA*ueB||ScsQ{s8$iS4@u9f`y&>9X~#0NCck-=JDy~_rKoo*doxGkf?jM% zV#hqrt8l|c+b6%i);A5lb>ZWQ-3?Z)K)`3gskAZwaGU1F-z>ag`06fKnA~Vkm$dOE zpJCd>6HfI`Nqo2Z%w&4b5UHLhsjdhqKyQtL`kpXoM7L@ElWSM9gG26l#FLF}ePNT33yK$F9VPiA8UOmB{J<)%t0 z*!zfNgesZ*tK8wyqenYY-68h!ZG-jeD6cM#9j9o;L{M}StTErx%>uY}bjfvP#g5x% z3~m(G9lH+zsYrE5odUCEhZa5WJfWUb-V(fMUWZ9}Wl7X7CqcdyuB&&CU-0J|}A!GZ)Ai5|T^cXnkUB+bzAGgt3-gC&1 zOM2y8u{|UlD;ZrXt9t76=}CkjI@HQZmv!IMg@?kSDn`KAeB#6% zSYSdOS9d?D9#KvA{y*A#^RS-t_W!%NW`@fQvskjnP!U4eLeY!~pR%>sDU>D3maJ`K z45LjxvP426Ei{%85<^9WsU(Cbq7tHol75f#U1sK*@9+Npj{CU(xR3jL+>YaNjh4^n z{dvD%uk&@D&*$@ep2dex;a=WR#3?)^!5?VQcwuzBFWG+#q+WO4B9swPl972;Q`%6zIyw~k{fP?=Cr4(KKyG|fXnvT3I5Gn--x-+p#l|dAi zJu2;HV10&3o>hh6a%X9Oi)6#3-PU3vGIC$gPm!^Wa8w1Na-kv!y2dzvUP zihu%J=qeL5nFbGmI|N4O$(7%kyGO*3cW!*H@u6g4^B23M%kgz^`LH6%*(f?nN=qlDGJ&CyyR& zjxK(2`{*#2f*lX0X+qLts+_;X>7v1R?2+x-B{(tjjQs}`Jo(oX(Q&ig94ThOLVrS| zm&KZqZ83UhU)3!aS*>2aX`Se0|B`*NO?&c`Au%yn?@hr?-I5zkYP_H+|BDiR9YyFF zMtX`Q|Bu_1$H?}~)oa(T9hIk(SgmUqS4uqD3RbH!mX|%}$T2F-!G*+2MXifu`!jl} z;hCbvxrnZ^L3Gvd*y(&b;$W8z!pt*iG=vk&_b5xkd*em}JuP$dviG->eL!*Ku$N~q z05H$qm7#JlRNk|M1NZXW?Rl+n^15lK7k>}buo)DVz1pR4)kRCIV3U^ioCGNtWP(j> zgFz+_k>B1*KzramY*zBd_LbwaM=o?jTxYV?%&E`3`4S^kggq%QhLU$1V>(c}NN3!eEllpJ8o>BmV4ZZbw$-D1xl{b*?d z(brt$8NIHe&&J*F-Y)G~SeO^<>r&aj^y9#e?s1!GguASrk z!xBgR&eCusP77ANL}xo1*h73pV1=-!FK5d>6|ucZ2l))9DU){C|G_>3l`EMhRUif@X5FIbt1hi|~wrXxc@C-NUl$X9Dn20SaJeiQ__v_}YuFTPy_i34~0 z*>ATbs2zE*>S=rCeGGjtYri&VOLt$5k^&U>V9n5)h&VHJg*IwVkykcQFll` zDTz74Y)&tpX}<ndTmHw{7M8z{=aG#-cs!>AHl1!_{9nH zUAmGWrVb-3Y@H+NAyoGP*ZJ}(y&QTMR|SKbstNFf6XwkIMV4hF0&epEe#B_^cTcuU z!>2qnG2Fq1;bNypQEq&BXP>$L;FmVEzoJ?hGG>&tG?0UzHu%elU9~t!|<#5s?z=q`*FN#|V|6J>>t?W?a+i|JY`1T~=pi;LGf& zX=iBkDqTtXEd)OZpuLmN#*OL3!pzPlawkli9w=Pz1uZ=hc2k%Efs?5VZ*d#SasC>M z+(eLjp@LY8UDiCmJ=j*d5gRqrz1%iGFK<>4BXr5#u*4?*ih33nM_QkOFC6KiMT4=? z<%57rqPAeZwGU^2^e)Dtni3fS{(pIwSJV+FLK)Aw5u#5SK@uXJ#-fW3ek(F&5ww!h zN{%L_C4S734&@T6{%}T^cs-GA`nDki+elP!2I`v2#{0((%_Zx=RKc#vDA#0rV@yv^hmK9 zD9_?m2LQl*Nr;WgGi>9S9{LNy znr{~&D=i0;-`DrL(p4HTO|x*f{7Bs&uK_LVEPPCN(x_RwV7I480N1Sxv+6HD;C1td zI!~VfT?*8LqS*v?IRitdzL5i_BCc+nirqj1eOoJ+V%6zVNnQWy+ZBS>BCLtzuqsAQ zA4sAqY?07(GR9>$c>_sr)$7z={qW~5^Jt@7$o6q{c0yqr zVbk>x#<2MTY-+-|InNMSZO1Zgzz`TO7&Wm`x`~MB39G}ePXqaV?O!NFDDnhBe?sm+ z7Sc8cZ@^G~T@E#2=O|H~Uf9S_YqU1y+1zBDlQb#90b-VrE?q8IQj11``ays((uwis z-QJar5!_4T)faPmi7;*&Ab~V(`hA`(`d49gHg4RkEs}_aU$q;Dvv^vaw5a3HSGWZn zT?eGvS~wrQAeGzP!Ss}s+`zmg;tm?KLaD~U-#qkD;PEf z-I;5I22uJ_lk$-hSPNuJJs|jjDubtnr`^Yoru;^j;qe=b-QK@FDo3{QSGY$aPA25~ zI?L^2?_b;fr)E#X%Vti+T6^>$u2jSMr7Q#qPJqx4ZZDS`a_3G^eSNIFc0}59i^%>s z1&b^V;8wI7QV${0@4ay#wmYak9qRA%$BSv%ZYYrD;pt~5*&{vxD&EZRR_)aN5qBL5=k z&M)b1eb86JpQfYoUBg47%s`JP)v&CR!R_ViriPQk#6ye6QhNl{Bw|*gw>)<_+9(B5 zZ@0uMkOV`Nf`!{V$`3x=n&GfMpF&c+i$rtEQnAEiIJgxNOdP58Gn9|vph2E6kwxY( zbLKu;S4L6?-sHpj)kuwj4Tj|NvMlTm%FB9P-Vy(>6@(x*ZX(z~w=7m{J@oT!`|Ii^ zfy&Dt!vGo&MyjP}R=_McF5S;Z@V+cDMM;x-=;M;YxrK*BAFDlj9jw9*{3?+4CG^wMa(;g`5xGZrkUP8kp&Kju}d1yO*@f^HiZPCe6umYXiB<@6Vy1IdCltTc9HlZZYy%eLxAdBmaf0`O{B7 z^|NCbtQ{J8-vM-5?5Nu-rVXr5%(E0x9*4tSX3`)Q6k#$FQISd&?$(m`Ql6DB#BimO zH)NLX-OrL1z(211;RZLPpt zWZ$6_GD#Z^S~b24-aE)aX`UbH-B-_u9)>ZrDl?*^rZ?{MG5LL zB}{n;S2UzVqaKyEeIlKzHy<)=W?gCX*sxk1DfKQpBoQ(h1*E#7jv;V!lL+OuKNm5*|B#ON{lW*q1(HpdT3?$FA8`hB93~ zDzU<#K9Dwk48T-GWjy5F<7RMw<3yKUU$p4WK+H&%3_AY4D%gX$@==5S@g8KI|#jjoQeVt0DR#{Az4F( zk)UcOl+bZq(-c2x;e&#F&V>IvEs5;4dSWw^Hgai~5D1E#0?_EQh%Bphc$}LA|Gl?`vT1fIs2qkCn zRXr#i-p9o}i=|7%He*yWa|!~$2*#h7mu|F)r5bL|s=|7Mt24sIvMV=c#L1FrBVz>?s|oYavV%~!5n z8!eDMvN?T~=w*OqbO@Rta79q~e=ZIOK8;8eh#Z|6#43lnuCDI&eAn&MoT+31iQZ4K zPutngLM;tRN|)S&=*fMANre}>1s&#F@)5qFy;k9tZL#D6N$cLD$0(#sZij>A;yUi{ zz2;&%GDq%=I74$i2l(x7JP=5(o9h}i6ZaIPMvr2*#VSX`xmCk`TkS*Lqd)oBk0@<2 z0I#H7ab<@>eFV$YqT&e^eEx69Nd!_NmA7nQuK_``r%VguT?9qNuwBM$cZC}xrIOZV zY8U6FE4s>+;q;4iiz;Xr=h^PEF*=rX{7Z4bLX3jNOe|dC>bg~%q}4vbzhupLw^{Ih zEf!VNtHZz%J(N}5Rrz6`KXr@CPzv%Ir;df4ybAL z85zb0z9%HIuOirWscvnFXe*V4qCJteHnrKh|2s13Wy6b5@%<$vz)$I+kpAXk-QnC01 zZ3^rbnvL5ns%5OhUi5>yr+;j$+{@gn$;l%)uD0?M#K~o~9!c*PcU``1)(?G}cJ9<^ zRGzKt7`IaS(O0^C)?8ddubKC+ze>v0Tx;R#4H@k2eoUkh+)@N=A7j|%dUlA=`lO&h zm(F~&#Cq2+1+RFd`wJ);n_R?NFhPlhn-(P`U9NB(<>Wa+-e+aApqlAZh3cfF7b;nn z8zho+e4z>uy9tpv!V=7$n><|c&3a{L1uKQ#yKGEnm0)G4RXCy8Q`f?75Tw+j?%wCC zP;pYfk~#Yhc)%CZCd)H3y*qMC&W~&>X?K(SK`?bO*p%#sfU}i&3(m9cW6Bk9;lhQo zG2i(Ek8=)S0A(pUxe8KwJ81v{F)kCG0oI)AFd728vGjo(rT^KlRL~pofkf2?%R2(# zI8a+@gs8I+6J-75SVg{{rhV~Byu6jy3Bw2*s`6D#xfwQWc z7jt6_kC&Q3cQD#2K!kgb7pz&mIwdo+&(-=>Z(hDOAK9keG@CF|uX5W@nlyV#`;#%UR$mD|*RB)p=ShI%(6Zt~u$OJt&~MtX6j5 z%034sUo^nUQve%C1Y8<~bUgq5wS@$7h`vTiOeUJnY1GDK z^|Cf|l$>g$KPiy7TL~~FnI)mv{w(zoF_<;qL$Ex3tNQx`!ht84>cqoVlmbM?df(2* zX`I@iVA4r1xiw=+yQ7>G3s3Ix7xZ3|V4ymjivT5pG#|-_Jdalm;BPl#@e=qzN!c`- z(sdz&r$sdiLNATPVd66027xak+rtzZ(*Irm>j6KtYP|ll>JT*2lpftl#~pVlErK26 zdZfAkMAA#U>X;?oen(LR$I$RT9}pzQbvGzZ@y}lm83!W88|nOfvE$I82`la0-P|DT znz})=_k8e%K!M}y5$p-B(uzQHuVjsD2MUhJb_}FskwI6&15;Y)rly?lilYFC;CObd zyed*kX-XCBU}=IQ3uLe{ z=`n&!=0mMKTHAG`7F7nA zEAQS}A{lF--<{)2S(AgX-Iqjz)S-3G(3&!9fhbuXqcaq1J}=+ri7jT+WX5on4vutK z7jFMXF8*O`kh?02wNr%L5aD7tg~Y6k6%ZkkLJGiv7TvaZkl3v7c3{77GT7joL^U~8}CjI?h#N}_< zwxplMKq|NB2yc3?y8#D{qFyp`AHA#pR&92uj^24>DaT@C`)wRbRwbnqOcreff12P* zB2V*dqJzm$xvnc$gk^-gsqZ3swKMH}7-Z>WzhC0fh7_@*rPW$HLaf@jhwH!Y&=U^V zCA_KP#qN)`;fRRApfGhygl-c5hNRY=PwMwC7+)&VUK-4oayJSrU}IA9$?EFr zVzLM%Ut@|gm6JoRL?i9`a{)e%9m^p23Ysgd4q7C-R0iC6(_3E}#Y96(&t5NPPl<>^ z5`BIcRY)(1ba6qP9hNX+)}jD~R)-ehHlb4NIqmceFo9^HXZ0huAGa2PhQLa6RkUb1?0lk>WXcCWcQwu9Y$Ra5 zc=1!WbZOOxHef$3lPQ#w1!KBc!k+`tC(3~z{EJ>qZrUXn&u$ZYW=>uu;%S~l7o6W)p^GN;nDSxdFE8H<94cW( zBt7amFN3CH2>EwRpg>jUi~KzIYiDR(gCU5c$Gtl&(&f7;BHF)le@`xS;Ofk1)ECPD1!$x zdW!>#f&^29k^No0jU>fL%uL~=R~32ok*n{xoAkGkCcgmE5dj1_uSrp4zXN@%S?8%| zWm%Xun7eI-Gq_GtOp^vb1Jp)y(aCKW^ll~5V)X7o<*`Q!F%97~8>to*q%$Tp>M-EW zNUnna+a?{#dv1S(; zBOkX_rM+)C7CMpig^6Lhe4lM6B0&Jb?JQ>LooPCh!RJCfAPu@jnISfaqT}KDP5?gX z>u@C?yXu2{c+=r7n72rJ&*I%iA`Q42CU5~K#vqUPr#S--0_URcu%mD++}%@tspBWK zaHGMRknyIbEtpj#K3P%^EXgDF2;&G2OKF?TZI}nvuj|aZ5m-eFZ<2|~D{p#z^63fW zsL~z7l9AaK;*tx21(UXmT))%&wi~L719@TqwU~>NaiqbZL0}(SG5Ba}mI| z;g(A1@Cw{p_Y)&q3M2=!gyt3+bJ(z9Ql2$wZK(Yx8alb6E`feHo)^o4y$}$Jt!+cC zF6sMp`qD7T%jO8oqBo@lV>{qBbf2Dyi#uR_OD@Q3d;QdgFdr6B>WtpfLnIU(_}=X2 z@PwPS1*Au6;pegtSbOIa<#~*3S5G3HX=SqdMB@g-R);H}d*$f*u@8m66HF2E$)B*b zPQ*{owpZl1vC>^`m>Bihc!6T?$@Zwn&mXq1ihua*>C?GOAyAfxkeXl= zCz9(scki~ThLRV*8p(nYVC}K`h>-QCx7MBRcCopX9}w$ql03r-nZ87h$0(B~b$9W~ z1G&y=u+~XqmqqHjmbq8CLG)v8UO!oRDC$&B92C8xFWi{K zX>j+nol>YgOt=VoBMmBj#U38*p~Ze*&FXy?YXcXs!5*QQ_>G2hbK9%Z;5Oiao|4Yk zemD9H|D*h>UT@{!qsqWB2%MQLV^L@Mt$9`>< zGy&tjNMN2lb7qCUj`DzqFBah4&grGyYaWNXsrg@eYP=wc$SE z5Ku8CyYavOJq~zJx^MV_?g+_51HnNb2Cg+93!j*F_E!SFqd#Cwa0Td>uual?!NHza z23KPAAdgVh!ur)H`P3cl4cA;O1(uh#^1dj?2IJr+1#ce zqIefH3e2l$`=vXF<^-y=a+a=Wb-JN*3(arL3G$$VX8j3jrsfVnCXrmYg1maurm@~{ z>LBHx?nJo>FU+u)rha#5hT#C6UC-baJ}0({4oW(U$hqz!6~Q0GBS~CQAh*-%DC&;8pPVNL<;Su@Zx&`jbz3PoNgz zsW{>}!G;D4P$XVi5Ey*o2^{&}F-t0Gn!mAqoZqJS_;>@|fplKRfWCL5$)Je*=$wHh zYrKI}_{jZa3@G#;TVmHu@QIJ4&1Pix`?Zm13=@lN!P7xirHeV2lukglIMe}D6ll0G)rdQ6 z>^bDz9#I7wW)%M#H)nvkWA96)X80!vqbB`Pq$5^N59iZa(Ep6J(!76;$ABp0BfRu;TxGdqCMjI^Ca26W&_35TKyQ4`p`GBmpblO0Wvcl z`kajJ5OMCwwvE@{{7CvrlMSbUjQ!)_)s=~&obLvAGD}fG#>BQC!D@m_w3KgfFj^5B zbxiDLxq@p;?+xv4Gkp_at%N1&(j|Q5cF~?jh!T^Ga942UF_hZ!9$b8X$6O{+iYXG` zD#rO-Rr*G5ua~cO)^-~|X!cd87shdLDg_C8M?T*ev5+BHh>4h(u!^k-ReQk4r13Wq z9uPSM){^22rCdsCYO3}ebR@#nddDzMT@QKz@p~%lq@c=DmGQHTz_u2F+tpE*q@#sB zbc+|tZ@-A`B8{C7ptf*g?sBnMED2I*I2OX1_jBYF!I4Gsdhy(^q0O0In@Qe#lx5?f z9Wmrm>1-N}EbIKmn)xlbe@2mGAjN$~m%$%Y zuwqjsdXZu48(VVxSig6ZMb>e@0ZASZp;)}Aj+UCS@1^&O>g~sR_Rb!jH!p3zui(y{ z1mP*J(gOh*bl&hf5?xY^Miq;0|Ms0KdjXb`l2G0(lb3$zx2?z`^89la!Ro3q$4~<{ zOeVipe`3gw)Qg(He@-~vG2}jWfAE3odOoSN_{RPhu~9o_Hd1dM6tOCVks#;7FA%l(7B^!_pp|%>*}7mSD&+u-=$F7sbERSRX4yUpfkuf zDiufr$_BY!Zhc5le#I+j@{a0%5lO`$IbRwmsFh^0!aSMDeNV(5;5hKvmUX-Kg47i z_gcO!ii(krxf*(%#r;$4=QzBd@q1M)330xbg3QjneCt2nY{(@82tv_24&S78m?$R$ zfrJwx)7eUwpkS1sITOff{wc$2?XPyQx9?o%KKKk>rR>l6c$4b>$}ypQ;6EauVwb~2*xP(*%|G+i;CMObl&l0hnAnJ)VH>|n=(c^N;2oVXIQ_i!9LRo?wVp;H8n6A{!OlyRA;hfM1SdvBDdvRbvNu<@~ z1HinH;*uVKcsOFTj~8wRh!_22djma{%(QM*^&xtda*PPb=TTEW5fx-1&rifJSXO_B zri4`M5P_DE^VFzI{AeFODkzltP3Gc?2g!22Hr z`I*mvC~*rEO(2R{bSAql^OIsBAl#eCXyBfun@&2ncvw2JlFzQV@wlD*2PLQwqacHQ z=`Im0kw~0L_ms5FWdFyA$E#FXFqsl{p!lLlxjLyxzE6ovNX|p}04oOvhjV>CsH0$H zs3Kl<&!=2_40(lk4&o;zY!ML*P+$urxvCn4WZg(dI?jAcINPjAiemBW0G&6hSHR$@pYi%s8f4h8hXHvS| zfq82;^hiCwEolFK|8c7#_Z@V6-)+j;_2GSERc}mJ{@Y^kywHb@9BMwinziX{{N`)x z^*g@Ic3m?4b@bfBPi=E5GS9?kPucY1eDy+?YP>TJ&kb3n;9QCZrJ9;HmUcB!3KG2g zox(`ing*_8zy7(OG!FF$u2WG8>}|%3wH)tTof}j1!|&>-N=r+xf0fvyPakVB%%b!0 zabcn3(Q7O{s0FF^xjENL3JVJhoD*m%(NT#brmq%+IP- zCv|nM6@4L;pU_VXc{Dc~T(_M0`(ecr*9r}hDW#?!e_O*=*!KPDGve+dAEiqda=PN; z1ukbfw#*VTDtbBY;jg>lB!^P|0)jeKYz~7Myz4ZS(~;)Z2p|N_voA|-8SS+-+n_P& z(5T$6Z$JHEA! zLo_PLmDvY7cSgb}sMMn4sLAh8D<$^s_J0#8FD)p`8PCBJb@}p-GNFU)U?&rZMg4al zr>8UIrUK-NubBC$QOlNhontok+XwI0J-Q}rQ+mva8qc$1NpsD!5>6Y8(%k~O;4@fX zmBA-+I@!|X-P@c=-K;8tQlu?u{E?RppS#rDj+0Mv!Pu%_`!m5))4$W0v(-`SQDGf! z8z1~(T|P%#U`WV9+v?xwX_sNebh<5jc;ORi)u!wjH*H$*@nHtXW7p?h^Eew4N--f; zj=+jQwegof==fRQ{U9@dh%AWOo2eC^S$SeM-Ny8i0mRStZO}Tl((vWcWUWl(ugu+^ zEzSTa<4+vF^xHRhxp7QOqpxPMg}IfjZHOrMNc;|o(L@CeL2pjxnvWjMknobbS)mEq z&cwpoAI#j%i}q~ZyzksOV;_ZQP~OvY_QSHZ_Ol%fgZe@w5}1V2)XX7L9=LaJT2z;w z3{B~*T8*&_FV@ec8t`FjNXQY6Dde(O?tMre`oR+NR{C=Vr*o+$a}*Rz?7#@M1r_9e zhloe!78aTO?zIxY*Gz`p(1eOSV_>k-o}ygIIFE%n)dR_}qiA9mnKjA^1cjm{llcK0 z+FMgw0oz#*AHH6@Q1tSr4upqWsBvm`>Ur0MlBo(6Yd=zY`DGLonOvlJ6qvky__HjV z#G{6|4JfIho~(NL0Mc_SFG4$ywn~Z_@+6U6MU}T| z2^m2VEM2vd|6;m>vM8K@UKBJp#1gJN<%z<#D<4ZECDmEaXh`Kf7qP9vDNX-wA45Z# zV6cF4XEkEaRO2yxB=wdp`lP!#b4uIX3|+gQWh3<^3;#e^SXX`h(GXWCuhe)|KYsLR zrt5o%r>>|HJ*#ta%P{C-gIfs3XwfS^jH-M}IT=~949tSwnMZwP7&J;UalUns6LL!V zBGD$12Z;?WsWQKAO;207To&!9vul)=GUWq!Eqz+HRKu2H(oSu`4~bj;RW zlCn`r)7YxPCW<~v`MTE^vVm8xK689S-OtJ2s&WQw@ zsm?x3Sy*PT_stp!K|qJhr27e#gLs0o6mAse%e^(7UBDB=v8)WxcJP{3%4 zy#2x^Fs1CiS541#HxEDiM&rwVJ1ioD*O;8^Ht4epKg>R(tJk$_Kva~A`OjlK4-_qk zi9d)^yqGTeEFg+tlXl8Q%pv5e7oJPGsYx~BvS+Fzd4`zkO4B4OFN4@F_1YnrPjMlT zZ-x6Rj+Wzp%G|SOQ@@dEE&;F6eyQ-*U_CpUNltmy0UP z<=N7ObvekJUyp|@3|u(u;?a5czfeJOWbrhdEBXYDy8e~I%#6F%++^mYdztH`Z-D^c zZit5zsEGC4DPv#-0+=?nF!i+Dg>gYfHCrQrkeAosSa|==KIPU#28H%(?tFOrO6>VK zCGSp~d$cbFuZ{|S`&9i(mGn{J8}^1ak8dE%?}QvHYis=~qdbm2DGs%H3eQgZDYtBE zqqFft~fbG$JPAQQ0otf!Wb$pJF0*TN=1>`%7S2Wny7Rp9IV?H-CR~* zq(m#-mNQy+OXExR4EJJhN{;TGJIe}|W0uAVfslEb`S2yn+uPF#cNYMUO&?R}re5<`*EM0wZ z2}gSWoavB!;bYq4z-CK&aQ&;l3URMyMs&9{YnK1~yFY(X!{(dq-im^{?n|awW_6&P z@=vczUH=*nps~2s-*6(0ulCpJ6_g*mFtc6&p#00ockN&2jK+p~se$sXhrPe>SN=o! ztKIPbL2vfk{OOAufA!V2s}9HBWx}vUeUKJ!paYDSyWtljlbysAij>{vc`G_c#10PR zVmwbvsBBmS{gbt94XHFS-@i%~$*|!+WSeOzj!<)M3;-k`H*udVto zNBAK(MT1#)Q}#BFEbl(v!X8{#9k$)hsn;v zg49CNGtxORC`!3$CROy;h1*B-R^pS<-ht28DdLmV!d?fv1T_EYFl};uJQBj z7N3KPAej|&AYm-XI*%2(JV|Lmd0>I$ zjNG~`X2V3z%pKzo7H{Up$DNr_9jM63XKa{Xvo2pcG*}#kqClapL{kzj@D)XwFo>Yz zT~#SpuWr`%nYjN!T5@u7s9pLVi=2cogN1)j-5X?8EzerF?lYg<9Bw-UyGE$Xu&omy z9%^Kk*}8(?J|6H#ZR1nk*v$ODQ-epAH}g+iRTROHCQ8FCWx3&B`Kq;?rGsU ztaPbmm0Bxje}>{zqFm0Q#*xk!OF7BKCGA>8OWT#-tNY?A;!+I-qpiXedGpLx1r?Nb z&)|urDB#V~N+@x}cieV$p^i^>7znPNST`^vFA7F>SI!!|b!j93!!yUXYQQNuAV~q3 zg>!JYf4j*Q%@lc$aserpvHC_ws8Pfn1SCAfqsmb#XOW=4NT~H-NXT%pW@aZ`tDz~9 ziJD3cgN9rZJ}vh$OFt&KAlnqjWqto#OoMcEb&HS|W6K>&5+6pp@{Ub$bhXEweqWhS zyIB?lG1m5npju?yx)ibEwgCwu^OXy_^i=Yb1IOtDmS!F>9-+aAY%*RdCpKngt0@2{ zK=9!b9SI|m)6briSD>;L->v?Y)h0|Nr~(nQj2=ukV1l4WK+1KaaX+)!503kOF>M;X8d z%ZgT*X2`>Vfz44>{0PF#{m9U-6tJzfO_F6@pjImJ??tAYPqn_vVptcD=aRzYp(Ro<4zw>R?8$+2~*$O|_Kqz$TAXtGZ1+`Uh;Q`f<_x z!SO@rhy74-Dxr`Qz>;T1bI?ANHye^A%zCB#T9=kQ^xqm3q^J6#Nt4*Sns}vwo4a%v z@g9O_yTjSNdMcrELRJ^)`J4oakj$6macY6@e+jCcW<8zXZX8*i0Glg5tvLIQ)hYhb z5ldf`zlxNSYir?7aSFb@ODENPVr1Q;pHx>;{g72xQf0_m7Lz)RKbOd&~|KqOIwwnE<|YpscP>Z0twt?S-Q*icvZ zt)-_!GBNhA6)_a&ZG9XpK=9Z*Hkj`0x&9Sb} z5=cBTeth$iw?88QBf&n3`k3=dP+>tf$-CmQ=%h=i(D|)$1sC3Z{LIo6G7(%6z)+^n zF&;^pn5nU{Ks3$o)pA~n08&RKj(IG$=qfxSCO2=_5C}$~fvA}CViF-jO8|+k*xQJY z57W@%b}I*(P2alg-fOTexh>qW-K>PLFl`F|$tO4`Q+xYb`3<0kg1S3oQ!UTKbbMl) z|EUGR|J+jN{~ModQX{c9go6~?o7Bqcp1rZDmq-r*|)A=VAfE>;=ez{t;{wkfo}YmRgb5D z9!WbV;E!C0Oc@Lc3KH~1B;_KWTXElYd*YB@@nvd(I`;1E$h3k9;_w6JgMz4wIDMl7 zSbR%QuHb?*&{k8hqrV;7Vo9o&Zs}h(Z0Jt$LQD1`O1*CV6@dHD4=LTVIcg`&pC2i9 zO(YX34u&c31ftM89DbvKzhm$sj3B@B*Lu2*pE2(Wc ztrh6kW|cfs{9uIzpyMSvWL@8&QFJeDXJWz$z}Kbi>yabdh$$QYE4&(RL=kf0&UuD{ z&?IYtf?~#iASVy!REnRAdJYEwb9Lvb|%lY44fN$=y4o>8#M5Td-kh;cl<78 zf6bW6qOWL|U*KWp*>iN==^NRVA3ZhYQa-SVsLq$iD`mBpGmlLi2oe(y-0s)X`Eaw= z5Snm}GNAB?LQIYuZ@D=7N184;oCYjBad+F!o%7R`qjts~KHEPEu$;z!FTnOrt4I61 z*4pc2F#e9Wn#c1MxA-^{7;n64Ezv9V!w=m1OXZ3p_huAKLD72)K<=>ZhC$eFE6%$P zt=7l6BS#T#U%EMI_I9bJLVgTWzuk8A)x1e!6U>5hWde(Ig+FmWBkO;0PTYiL8gF|+ zf&kKwqBT*3Uw?kD+hO;QLPwW$+*0ZMhl-eY#=O=8%w{{5OA=`h0nz2?=>5R!2eS z_qW(*d}OTC^vZkYL_sPNj8f-sNd9c`+>SRs9e028(@(ax9H=$WcGcd8zjljxHtR`2 z2jR7V_N4#H+}t~Bh*>ksB|~YD?p;{}e_2M^7>Qm^DsL`^ifvJA>xB!Ch>No9UMXBj z*=3kJ&1MPb#&$3rczTtV*!mx>AqAWBV#OQd-rQfw>)yXk?O$mza%8|&hYXs#t~nnF z+CUYme^R`jc_g=h5nSZ_9GFEfA#y9XV5m4eErp-&?|2=0$XG+$!fSt z%lzlKi~W}xTRdLrZ_s-EOq*rjm`)ZIvKV&|=K9KR5xP^<+jJ!svq1kqjUFK!@42x& z7WwpV->$q*#WNxK{iO~c#Fs-g;KdIpguN)<~M2l3dVt%I6W84H~4h zYODq%9fu9;S{f{RSJ4m#`6O}GoI1q?g;ZUp94GK66Cz%G^GUZK(v9;V8CjsuC|S=i zAD~B7Fujp57-}rS1Ryl{n%FoU<&GE{3K0$iv7<7FMhKKWasqRUQWKvf#Sf{w@7SZx z{g51uctVbEHhhS9HIQRRL363Y5Tsj6ehebjk3LWFduFeQUp$%qA^5vlKzT1~_KX!xn#rHExgkKl7n=T7m)JomT z=T6z2l3%3_&VYtwJAgvm;b$+O!@2k0V=E0aqS3)S<=I%zsxA8;eVX;fD%-@)H-?Q* zcAfTQSjdJHueghCWKqW%cy)WhudUbajM?1$oQZbsvqu^dMV!+^0s>}?G|b4zFgN}H z3FAu?x)X1Q-qsJ6ycJT+Q?=S2oed2${{6)lOIOzB=<@$sMz00-X=Lm%A3Gp2KAT%}I~n3?6`0(d{b;aq7_TOI$2?Nq~m_+jm(>WK}^JhR8w z7%#5wJ#^A8+(N&*`W6{ONs(Rq<~6(T7gfa|G8zs!y~u8yX^Rmd6HBkW8JD%B{oFl+ zzZ=@=#xTd;115D&K0e(4X)D_Z@4-{j+qc6KWZBzRTnvxpc+^YK0A+mPFD-4y0eDEa zA#Fc{j4RM_m*2ME6hS2#uvM0|vQQ z^ATqNu!{NS1a}{;tCerwY~gu=03G0SPA7%bNm^K3bH~h?+)-RmL5e~JAEH;A)|Ht@ z_k2%=Mw?$WIipzVKyb@+5Od#8y6VH<3ce(hG2r9pA9dz9D0&comZM_6OD=rBI4R)3 zA&!#T_g6MCBLY5fdowybIy)1WO;yIJT#p3{86hu2+c|K1RP|>vgc>XTlBJvW_Z2ZE zC28vY`_qLLL{}!7%}IyzXD)Ewv@1v0JOnZf6_HaT3xw?my0Iy~(}wGI4eZ1uEWh_i>*zPmvi zULYO-?3!?rwujrYJI#(8@~)&x3lvD?kJ>6?lfjMd+z-v_nh%u;*&OU}WJAv0XyWDk zc-7-orjCo-KXoG)^WM?+gHO9n?PAUu(SJ6sk(3D18XMMjlhq!hCBMHPQ-X>xsbC0) zMuEFG)$f(e%*=hQB;=~Z|GkXIUuaibwOxed0C?T|_n#m&26)KxK@)(`#myPQ6Rf{T z_AG#T@CMvy0>~XQsA#Z4nW71Rm<2o+PIJn91pQyR4_719-T(f*8-n>U*wR6b zX6*S;)2;j|IzMkW$6 zXmUOdwCXTGkz(mSLen+xjr%(@0ct&KitSd@7b_xvz-)B3-vcMw9IXzVt<~+!Bxx== zCFar6C8axR$6hewD|l^zf=Tm1`Ptl>F-S1Fm!7_UjDBdjXlXV=sYQU-=hmha4E^nX z3MZztYJ)gA!jiKbwzIZm5*|>rYvPM+SZ9Zzgtj7W#r?}9G>Xf(7eIHcYe{!!gTBwJu`Av%!ZO852{FLP4$xp>b zProBHC*3pkpZn!N+txN^owu1zZeh}3QE@VG>m;HGDaNKBbE<^y@NOR!mYhqnF>j=y zOyg!C{jIkcMlFoyhCHKMCGLT1F3KKAu6Wm1r+fF*SGZ8iju#!ny^vkI7P$ICt+hoa z+=sO~cVYL{<$nNuFHJ2?p!w&LgJEkzW$fMlXy}Hik3w3x>^eNLo#g% zz~mj(0FSp?q2@xT+{DM83iXfE=)Z!7blLi-ML#l{0f2|T42H;g3%qr$MCwh>_(iTx zQ)4dc&39R`VgNo7PD|dOK5aBiO@D=@3r3SWzq>k##u5EZU*7xT&CdryU&?%NpPK0B zHU8({=K1vBgy?bMNJDnBh3{ONw$^LsBY!_RFPBv}(=f<8UQ;i%d}DmJNdj{toJSgx zLt7a4(s;aT{r4daC6%h+NxNRUH)Eqj8{^V-onN$KdkTu)r7-DV8L+bA>Dv$MACkXa zEA-Qd&qRrn^UhYkINRQT-?P|3oJuwTdhd>wqn>gt%rFyE7vv9-+>yJk=UUpirRP;r zWy^FS0!1dRoi`a&s5*A~%L3w#=$P;jDF0*F*N$M51=s?A9e&_|*hrE&_D^=NyuV_0 ztMzE9qF*<{AxPO6QH{h59D;K z+|mYt!(I1ukY%Tg<@7GGf4QPB3N?I~pe;kqa#7Zbd>kA|!1{&n1_JC-x@Qhn3n9o; z!2M(NDdRo6?zp&HX`1c1BlLgIe$3QVQ=$r{nYJq-5Y)wY;rM^|Xg|6}h@57xXiQK>Nypnqd8G{sc6 zuhGGZ=Z|!`N#E7G4iO^#1I8Z-KA$ZjR+7Nj%cUM`W=_3)YlR^x<~fhyOYl6>RRaQ=;Q)#T>xbE3O^7WmaFtyNm%Y)|`lzVWHk%!Q|aY8O*> zr8L@d@#Xe7C64B-&S10zjTJrd%S(>PZ*jQK%mqKt8xW{X|HP4L$;Q$jgVitulMF`X z^cHJ$D7G>bTXI-y*36|4>JqZ@huTWd+Hhar#x+RZ|6pKi-x(J)q{|G6F3f+RcFmhT zwLL1@=bgNsZhPd2&R7#O19D+o)V)hr_PFq^G4)+ZB>Gin#Pd9{4q6C=gv?sEveeBb!upZt+jvxEr{%OY> z+_HAo_ha(4UmIW8Ry1TlK>k4gb~oC5KH^>}axWbfn!CkSlXwaYIP)o^Gh?zd-Ui#2 z<$u*;4d9nF+nqX7T~4K@4g*I}sfWCOHB96xvXIMJAydo1WFxQHH)OIS;$)d%C(;gL zrO0hz4mnEh%TO-yuOXJLMnM1E`_#F9z#|!c>ux@n*zA)hdOdrN5!#$g3=bObrq#mv zaNK1{c}2w}35dwv;DuYXY#9!3cDvA^mUmZ&!2?Zo?$dY}SwO+SquPn?FR=q{S2afS z{`d;g-g$^|G;_RPEK6vAy14O^eS>a%O^$*_)@sh2IYIV{hug7Z%Mgt46s9^2X1#~C zbs3{d#^3f=kGGE+MtL@ma&oaqpNrRG!Slsl^M5kF(T3%CtjMi>7czy!i0dw&4h(Uo z1b<}G;`Dymc3jhT-hq5Pc*p~qnPm1Hihwo3>0B#;t(wBrUG2iJJ>UL(KY3Tb1=|u| zN{>CquM8I-{k;kUX|z5tH)i&p+tWekq^2hM0v>RpxU)#Y%gqE&gSvH{`f~VT?0o>Z#JdAv zcZc22n(f=W=gczAnn;A=4vWZu@Fgd}B?ET0i@8dbPK_%KI;8kr5qFEBJN-_Gmk&4H z^euIfpjM)2z@`=f=juXo$4nxXIInV(vmd=R-%p%ShY1ElSP)H)RIol>ew8vX*6EHF z{6(C;!6z%ecy0XJc(3vJ%Hq?nhlgw!Z05SBEvE=8=!!<5FC~KzEgL8KxJ;@Q%seMSsl3!tEq2=**ChCw&ztRD`@WW2#gr>C` zygYgdQyE$h88T!&rF1%$lKFxdZ6b3}>hSRVpubdWQl4FD;NL=os z52cS-H8_J5bK|6HBxcR0pRS2Ybz(E?sGmzNFL&T3xRhWANd?r@Z6ZlTe26=G;u? z<2`s)R@Rpk)=#qgR{iD&UVCZlIMNV6J7t+KvG2yezx=Wf>Kw<<8Z>AlWuLFqc&S>UID5IY+0ZyHksCGcV;k5`Cs3IAgnkKNrufid5oLyb`fLpjj^?APe-+irp};Gp&qw%AG`^+C z*~fK5bS2%s-3Iw~`;G=a9!ZP#4Ftz;?$hpu$K#cW`Sr>a`I^R8_J30=Ykc*e-~0cS d#2acn1ER~G8DwmFAa_e;Hs;48Kg?MBzW~2j@HhYf literal 0 HcmV?d00001 diff --git a/docs/docs/assets/images/llama3/lingua-8b-vs-70b-mxfp8-uplift.png b/docs/docs/assets/images/llama3/lingua-8b-vs-70b-mxfp8-uplift.png new file mode 100644 index 0000000000000000000000000000000000000000..8bdb78070e4fc320948629734bf0eef918071d4f GIT binary patch literal 140587 zcmeGEhd-D9|2~XARZ1EtG9pq6S&2kOM2eQ39VMe-%O<5Hqs&N@kxh}kiYPm>8b+BJ zWoD-9c)UKJ-}U?c{)FG{y4-H>@$$m+JRj$M9LI4#?&tmUXO*|mve8l~lr5)DDymZ` z8)zt$b!;^2@f{sAfd}{>DF>ws4jR^19B!J}nNiM|IM`gXcDQC~%H?EcXK!hJ-BokB&SoKlq6 zbdDMCzNyLB`A=eUzSWI^@e%ihR~O~0g;FlcU;5qst|#zB!hJ!tguCt&HQP!~1eP7} z^Ybr@H`N*U=H*sTY@FH9{43b(ib-VJwCB@)NjvkQrl(HkPO|pD!#DbHdHemJU;gVd zD}^Ww|K}GI`+A!D|35!N|NnmD|6%7b{lS}REmvr5*lJ^pWbVrcT8T#_CGQ8lP1Q)I zEx9lo9wll2Vnmuvvwcdcn|I`W-v8yQ=vee03!l=PreCc77^wsMxY;Q>e6Z zO5e4vm-4Uw$;ilTD!HIZDIZ#v@%T#H|M|OW8RwMLU)QNs6~$S{i{20Y{Ml2Pt0n)w z-(bM~e}64~5jl2JQE@hWS-oO`hkI^n(_(gSL0irv|MjuCisVb_$ypQT{{nlu=EI{s zN|slyT=6|0BCPsv_3zZL?xRU!M~{}Q(Ul0^$j}m5+ud!bsoa^all@TTxT3oA(`^h4 zgTe^$ln;+w)dUJGKka*0Q^TKBR9N`TWo9U3CQ{f~wz*zeF-)UALh!&ixBH_jjnP^b zV?zgbb1#W5E-bux|33NU19?T=K1UTr#jpw%uD$Wx`_q}L?``GSMolx;ky+1N`rn2A z_g}M(eJpnr6{-7L5|mjEUD;e0DNJ|#`0)!DE0Q)+c}bxon+fq}hjJBQj* zHR#t-=@;BkO4rW3c=hTg4h|0d=*gj}=|R<%)#cxlfBxKk`jjIeAmE~f1V^)IjAzPj{2N8m}?tG^+kp|4D< zX~k}QW2t_?eDBsRs=5fl2j8pi^-uM`=bvf29P?_tW_`%Hn@&y<21aS)x*VIdOifK~ z`rc4A#7J(R_>y;gZnXXT$Ow&whQ`R~XguyYMKg6<@y*{sfjcF*g@lCaS8}A3gLcXI zH$6RC)zox$qJJIb?c28(axQONznR6m*l{f0eR29_e*WF8ETPk9&V0wET3cH;JU+M< z7vAd3dHJ=+vXk0kSy@@G!3(W)?Rc}*zw^dPF59+mf8{ix=P=r~k?DZ;ov5g&A76y@ z^94pbG9Rq|UA{E+qq@-ZZ}EVuSdsRPp3?hqo=Y3Iu!(sG25#BBWy|yI>|pgImACEf zOqu!xHHoJ}U*7y}RO zVh78Sv9YSUI@-o)@v4e{tDEHMHE$?8nS~{v@+i%*?LX0ybN1}n=<+GkP0Yd@!L#Sckqep9pITdAdj544yZ(hyTU#4H zq1kr(=I@?HHObJBke3!M+-=#$oI7^x5IKDKM3PFli0zMe*M4iMLYwpQ~4Ym9L{H4I2a~uylczO!|(du-hO2_SjQeN+3#W6UlSPr6G63&+VhRj}NuUtRctsT(Vplsh zcXV{%&J?3W%_ghp25)n9X-q4=C^)IKBpYWCVuhCl^1^`*N~q1Si*W&X;~Qw zhd+Xl44JQIvZs8OyR=z#aHImyuc(a>7vVD<#`}W&EKP4wvwWpEC zaeo_6-^4`bk?esVKYmPh8hgEMY}|s{C35J{@l3ru5wBJEy?ghjXa%8InXW8M+1$K& z508I#b{1E$-rU@rgz*9Gj5o>W;@avC%*>oY_At63EtMALdEAFMIAIQkee4d=V z;p)|^npp;{Po*3yrw8k&Ru=7BawKkidy}S>F5)n9UhG;IJ@QU7=3Grp4X;)jqqL07 zLCe;?-%4E$HRl@(9X!Y|)|su096n#}?t(gZ?A$pz9@PlN(0yknzgk!9Mm~IXr%emu z=3d%Iqo|lj(u&HvwE|(4o_7uPUa7w6@xkK@xGL(kYuC1jEy|%Te4iO^wu&5{{9ff# zP*A|z{q389<>$o2y(T~2-S-I!3aWFg|MwG*vHyLrj_y`EI=U87T!v0b!>wDlLS}w- zy*f_vN;eD9MbjK77a8+s_NPW($Hed7!Yc?!-5dV`SP$*3pT~D=7vp5y_GxD7X_$$>+k6)>JH*IWQ zo=?+E706afT{=Op*;|3a8#idLn{XoV&d7*Gt#r~*q+SIcLCMe2*xcx75fg<7mzzIZ zm5z9>6n+_2k?;9EsQGjcVy`hLsCJb7^Q%kuE?&Nz_$R+DP3yotWv<`z`tog=c4rsqo_7l$mp*uIEmjVz|=q zcsKfmd5JsvGqs#tg~7xg`43;0`VhqNy?kk_fgDG*U%R8zh<~}acKzn_OJ_nK?VER! zr@w=&+=yKN;PjLPpbzyX%^{74j*JW~1>RH&-0_>IrH)-dcIE9-F>j1ca&pn6W)NLZ zgKlwWhoGLt%AQof9tCyvkQDYs2Cmte8I3VBr6Rk*y`S2*tVwa(eDuwg1jUi7JT-I|j9s zy38ofoQ^r7CR%|?;5Wn}ZWU(gYGsvn{<#8mV13*tipkj^bNTf&2yg)lOG`8T7vDOz zjIO6?mytjv;VqK+9Ip21lG=GH1&RWikZ)U?rkW8SA0PRLTtg$IMtho8{U6tDcV?dj zr(KGkii?en9jNbVpsGY11R8CnGS>D!|I{rqJUo1I^{<;+RHD&`Z~MkGcrITCXD*DANY?Q)bQM+LM<+WwVV(PNH zMwzQ)twZYA-(TG&Y@BrH1hrfyMjL2mMqP682gzY9|Ctc4JFcp#+Ewo17BsJ6)Ug2# zxRp1ql!kr{z3ZQGKHXhs*4*)0T}-YS3-_XVs3za3rl{DvCP@CIVp#AlS$7(Wjg8H* znW4t1?^PSvrJc{Wc{HGtia3sHj*MJ37L;@Qvzfx!{S_E$w}8Orf*alH4Kww^6>plG z&%MsLeS0mAhBoT--8XNJsTFWjQBge>HKTR+@DT6OKC^e6+3QX z7BV;=ynEG!b0JW>3l9K_D2lAD`bKeYrD z_kDDf_P~Jy7j$*szI(TBYHG@PX_imOpwN3@;Bs{-{+V1Tz-O9v&vcflom&jQb_M~~ z_=V?yswdB!QBqV~CuCIW_u<3IB$t0nW5Sy`W!blG-D(wnFk$!cyM6P$UOJ;^9TAX^ zW!w(W&dr4nb?0AyVzCH-&$sHey5b`3JbBEr7~qvn;>PCv`}ZTzzjv41L@|;5)mssl z{wtYs(ur?)WW*0JtFobCGr7z^#x@I;|PFOBO{}at*w={q5H{&ZiJ z&eH>@vPS(nvyInMz%^Qm9j~&mupoTbE=+#E-Rh2BMze9Fe4bT%PU8&V9i=tc3xJt{ znRx?>iK4Rdfks?wK$2>t>9;~#Ng*%o>9Lr`ae&yzM{j2F8tUus?CI%2vYnTf8yvX6 z)b`hi$ptLS{xXtyVP!G5)qQAa2+);a|B#RndIpADA3vV2w)Zlso$x%6QY?Pdy0_f- z{{4-p54T^x7GsxnXF}j$KENl8gre0j~s&cVUJ%6fnM{v!m^sR##W8GWRFbPP=>pf1AD;{0)rB zW@*+^&Glz%Qut~tB!rX)dXyCPBwDMs6e_O2MSl?`{+9S+Iby^0?b|myO`Rbv4|;lf zZ}3Xhv-VC-C%}qKOig_d=rM9$yKvWyIR;HlO&}kyyany+sjFD+(0ua!y9DP4UYZpZUrA{&P9vBpW+8v#2qf?%~U3ph^l7tlH zyjBGq#yZZ3#g&(rPmZ=9NGhEvrv5JXZ)H!BA*t(OLnU^D{uiyR_$9N8Gw)FRin&-1 zn%=S4g-WgFnO2^d_xd%Xf`UTWkYS!x_(18^jtsSEVPSpxD>_Z3cK=pAc8iFd|0&%x z{`U3;A5`AfCH%r2v`W9B#fAAvRKJ)U8D9sY5>$MS8 zK3p`FompFmrG@p*VFA1_#&KbToLgt7ICx@It{(&>Y{o0*NmXjFV@tC1!+=o znVFfsfD3#>M{WAN2iyVnRk{z}^YuNIzw)iZtD@EQ=EN7CDB~fasNPbSJ(?L8)WnQ| zfWn9L%Uts=dltEr>y|pd8jxHZsNJ5lNV5HNNNaDeai+&p4VSsnsIS8bDC;z!*zwZE z4nN~HIq%)O_u$v{FJHb?qWOePtDZf(;{-ZrT{aTMDv~duG>s1Fsc6>v>C^U(Y-8b3 zp(A*9UAgALBP-k2$#)f4$3YH!kyDzQni}%;Uw^Qi|BzMzwFzpB(##f4Sp%_^+Kx

q`4httczb=3UwZfa;!@oO};$g)fg6vdqq!V9x9XDZ^i=tSLriuK6xf zl9DE7Dy*c_W>iZ3sh-63qW88Ty1j-XK~%Z9D&$4pJ*x&I;>k2p_q zK?+D)y=ygs!W}QBqodQ(IgOfKnWC03`JU5jz}YY>D{J9ra$Hy|O7lH{Q{E&cW##^= zdvq^fzWi{w*|xt1%8wS`^$!ns163VevvzHGD;LhRZXB*&6M+0K3sh$DWo4< zxWFRov2^F%JH_hANYGssW3}X@r0-wyuA*pK=8t`tuqC%g;Vt*~SeV=$pOC=7$(e2O za~mOV)-@f6qS=l*320{-q|OfhL>CJ$H|+)|8L4mM^jwhtX>D(BuhwJTQ~JmtIVy^k zl>WolzYG+9i|S^Qu#9z3fLQGKaHHGr7x_URD z0ARpzC(#4Q+>LMEP@rzqK+X!+6Ht9roMlwuDd^%3@bM8P)nVd`c)j$dRw$z% z6Hjfg?ggr25woBdGAvFMi!0TNmRenzIj8nOldH^#VI36|ifw0hN;``{A$u0q@l1a0 zW>(R45JI(8ddu7-B7_XnoLs~lXNJzyc5JsFZc+lbV)7ar8M#QVG1r_94=k*I4!1!E zsZ4#v@Rh03>l@u{|5lf?Tir)@t;LxOTzE-AYWR!c)&vZTHQPj27iWg?z!y!&LdR)* zxOC&_*u+jc$Zo+$Bcb2m-?rD4i2RyQ8<=IUHIPe6O4`lMt&l@~1FZ35h1WlWNb&4H zUnA8E&PIt+Bkr|Ux^v7_Ev{aD;HqoUnZ==3a3gb}1Q5pJ<1;EBAD^(H`Y!Ix6sKR` z`1H)7q$*WtsYD3e2WxwN>3xvOPr#{~zsvJ#xOmDN^a-WK!9R;vfFc{QOy$ zud%6VD0zSbG4;j&m7rExA+`c^@KM0|XFSgU^iwfpE3{O#|7 zk94doH&QsH9OwGA2%6#xkfi)f1Ux`j=QZ{I1rBKyuZEa7UIz>gPjmVSCyu>N(>Axzo}N~BwdK-*ey8Sr zezZCxtlVSS-{&@~VX3oPSVxN5o?1>;A|jz^C>b^w6x!@IUVBGf?jm3x2+BT{--y1f zHFuz2FeRPh*v|ErC1pF;cw1`dWmiO$B9%hw^}k^=FP&w!T}J0x)mRaw>q;SYaUMh*c!tBnH;xYyLgcTP|>oOa6`t31z8oy zr}5(A`yh_^;FCEaE<2Ng9-%y6?yP`FV9^p!oqwY{dQd8HI>Vqy6j!rHrQp>og1av> zC$qY^xXAAd)aNw}4GkSeKFWs~L2!0h}h$;+C!Jpj^ji@nsv9{ApiLgZjb!Wmn%nAI+$_k(T2V6cdwgzmo=r z*ezS)ebsNSo;r1k@J5`}2FOQYg8Dm9#qT_MvJ>s&dZD^ZG6?3z* zMDT9ez{h&~ww-##eNoOmsl6J8#SV`IRFN)F`-O`Oj0$b!8Mqq8hMz!aXxXrbGTD8z z*R06k1a3#mttW+8&fPT(Tb!PvxXfGvXckqmgP1i6s$2PVPR;}x zY@z*d5VxSDBrEhj#z43JzCLf1p-T{3si)WKjk>lBxYpAxss%H@hO+UDoXUm;swr`9 z$AIyP39a&D8pB5{?Y^})=9*UfKbCUjA8Lq72wNpua+f5DR@B5RO)*izf-H6l-6ZO= z@PoQWRqD-Ky+`|>jrOp!vPM}`DQFEF)$HJ69elwKIHr<$chJs$bB> zx6?eL3=7<;-!?Z3zke%FvGPia{nQ^89qo@tF+7sQ{LFJvb8328fk4I9!`3|k3k%M* zS#g?<5=Sq8ppQ9vb1O18^qL8GB!vDw%0bUicjMnA7Gr#1^Rtzw;fcgK%|8uBrOvm= zL__Jpjv`NnxxBhy*V1QaVx=XO-@eW69{-$c3sMgyY+pTd(Ea=O7k*d#JCE_A@Qe1F zM!h9=)8mwsa*5%&+~a3|-=tm~r?|UxmbI;}Y>Va4tZkg)A7A#kQY52HF3;=VUuFn- zy7^)EA?tuq-3F3Ub^0CibfwGMSfQKWQRtLMN=T%mqUTC%9iO9qo*7Stn+$Kpa3f@Z zc+3C`-R85@UBOid7pWEEy;a$rOs&Ov*?i!`BT3q=Yb3KT&0MPR+DN3*&O3r|L6nPr zmQnoL)#cZkSTkO0rH6q{*lNU*g47JlmSan*sVKk#^&1j{e}*qA2c~gK>V}uSl!hig zeJ`u+@MX_iiNDI%wZ7Nny@r6;k;Eiy#10Kk0g{@@_3N3=tlAZx9!Upx(^jH~?0x^1 zMb1)+E&fo-A2Ok4c=?bVTxMWZ}py+S8bbV!_7dzkzO7A0mY) zTtS?2S#5I{3=G(z&fNvRPJjBZlY4*j)?-UDM`jM)k~UPSahX~2jI<3)BBEN@YsTz0E_MaGx61b z*#_;COp?JUt5^PaY8N4FoYmqweYL`4%rQMMFmOaq-7%Kdhs(vd&N*v%ZYB`Y@WA5- z08#hu-#?XK&m89C*>n2j$zyl~wsWIeNwauHRp3x&-Ru0k{;lyYDLFdYmYVQq$^MI$ z)s3;Wscz1I|3CWsdGu9P0L%g}-;DEEntgj`(>#0DO zeRKVe05w+OEsc$5UatVg@Vx)}f&KF0HatpY(PPg4!Tl#sp0xIk{;EHqlf{KGR_1id zlAesXxOitKN1~C@>1eT2VnIEK1RFRN3^EVwPps2PQej&8?e%Yal4@Y{BWBn6Uj#hx z=6!igj2}uLp>JkudjbVRz413?KTd&4b%?NxHZ64=H z0f{SGmZS3xx_N&jm)nc6K>$=E|Do{Wp7JZ|omNJ#zXbUEhm8NRM2f^y?nrf>9Fp2{ z@?ZSl*`J@{ZE1+Q!q%_dEe$9&1)WJv=gFf-k3#x|%#<-E6t`+S&-Q78(6xdBou!vN z5_5{4992GKMA-G^%MYM>m7j>hFn39QCWrjIJgrfF^-)huq}@Pi_#E9L(59ary?I6d zMC!}{#AJvQConxy98(+qcHrdxvoR0dcN}&5lQH1zq25dvrGKi_c}n4@5mm}Rf%OBu zl4IlJn*q3zn7PP1$xS8I6O`V zVb2M)?96&x|^DroAE1U^jMJp(o9}|@6$SlJ_rmp(kyOU>^A+Od9K0Rc< z6|(65QCGe_sU7Rp$EG2+p5CXi6uCI)zfZJN=GfAM0!%3mB>CRGYx-yWv!qdTgo5o@cH0%}~H*ZoQ;L8$Y4#qO?a_6b| z%Qm}brXhI7M|eOdg&pR8YcHz4cDjB8JO!tMjvjWMwIF_!shcV)v_v_{krfmGn0PE^ z8Pck{2z+BQJJJ&G0vC~5rk;3`Sn4p5T8Q>GHIEjXtTz8?(XpxSvS{tnx6V!-&G`)7 z+!HzNDp!mv|H@`VX`!ItRDr>T1sR_h?YWm5+PR!5GTKqe#q>-HfyjG-k(HItWlKqM zaXe6&f5lAy;tObm@vq?@Qt@2j!zIWpAX>n5;*aSW)z=R+GFfzO22k&5N_rP=(UiD4^edE=l`jo9J;m5)J1iE@ZHsP|rqaPdCA>IQ9mc2e z5@U6>F?)G>8^DeXbkIlpv~-Z&-vN45EFo+oxaB)`Cfx#m@&9!b3dBg)tAx6#jVLVv z-rke|%yefNG}KZP{QRc3>Hk-RuU)@>@Ja*A%HnXG%1<0V< z_ymZ2SStn#cL9U=l4LMifUu+h!xlXUM<9M6&7Gz>1AO7};p@>Br!%16Q2kq(p@Z(} z4~X2iFx78WMk8d=($Ye#EPCMFkg0zlk3xhC@bN7rH?5Oip9o}=hiYY6qOPHV+}ZTn zx_4J`aWNTR>j}3H5?e>S;>H$H%o5-U(+-d^v$hU~&Z2u;kQ(LV_Sy|wd6WLa7=f%i zY}J_+h>jQ|X}|CNrE^2tFl8xOlFx@+MCNPoP8gIp3HIts1c9W2cI!yi1P2GxL1^HU zB+&k-)A$=qM=>ElB_?hid!^)~3@?HZYPt233W) z%3zlv{TUiyJagf(rjj$QieP{TuYVCs3Z|o69QpYCLWQS{fuSM#^(~YnFMGLxLtd;~ z6E{qK>+V*=^t1XNCwST|xZA!gLJ^I>_Bltzux-t;Cx@?p#MuzP6d+Iy0Ebo(UF0KF zP|uYqBT;djwT_lVhO^VqMQ4| z!XOUv+YdB8bhG!!x3(LC*-1&h8@KQCGpgw=v|Wq&oiZGW7? z>8(YWc-$rqJJf}j=#q;odk0mEdWU^62_*iw3#HBn;ng4>q9aF-R$(>?-cAFj=jNV) zM9;dhzuGL0S51&td+@B31UsNz2clqC!P|sF1gIH$P3FkWUj|8}G+aRkA!STm^*zFo zTqN9oWF`bZ5#-Qdpug|%0P?RDT71qGCjrBPFXWB;t!NPfw3SyhvyEhML>rDR{mutJ zG3hCFQJDcXp~Jsd#7vm6c^ZLL1t2$tuB*h+S>fdsSifb!(Bmb9PG=VvKI*k3^}#_+ zIx=+slU`1rhAZPfLg5Fln&qfdhG`GNp{lFv!q3~yn69;D8SWI+&%aYtB&l%xxGheE z{kpj}T3s;W`9RX*;v$@iRDf-FFayW9nT*e+U1!M%#eaS)tD6wT|15`+Vp#GzZf

    2@Bf`hWa3)@s}~skuTg! zYM!;VwF=-fxrRNML=#OZy#Y1OCqLgQ-vy;X<%OKbl9QiZqANVDWAr|He79de= z7A8$<9sj|*WW_-Tn3mVf=Qy2VR8~BYqq~ugj=7K)>-+cJdc&Z`iZiJibhBff=LVPP*pEE3lhhj4b^4dsMD6$|QqwT0sfY#b zKrgP?F=`*NiTsbD1HCFWugvG!s$_EQpOh!b#bzr&5qh0pDt$MJe~H7q39h#DV@JH^ zyKjd4j7q|w2AO*$(MN?29>0s1i+S@66Xyms~v;Mqwj^Owr^Q zFK9tCtp@oV&xJ_Mi+wg8?cq!u)kWon1ZJlH3nLL)7%bwYIc42>;m{ejFiw+7wzF%p zGKCOg#c?3jTqf1NH7l2ZGk<^%UF>v3R8liDO69_n~%Ff=AcU}0fl z;FQs5UH#HU^Z~ZPL-6+{4emb^o!DWg-u$4*V^`GCn>T{~xDCzisF~penGG@w0G+t) z;OIyO0K|Ng`HCLyDG|>VXLd zG)<@Q@S;tReFosv81$Y4=p%T`-Q9g}^77J>N>I{X^NSa~!TX+{d2-k)(h|ZAzkbT# zU0_qEi7!bxJ2kkvB{>UON6(>yo5y9Wt^Ryc_Ldv`Lt?*#|)lvXkRF?w}rx_V9 zUAnX~*J;caM>onM>wW~&j@UtI=)2Vju=cuhn7ET^XW(*#ODt5j1mg-%t3gW)mT}vf zX_}`7rBUQBlsIL^^SEn^$uv7UI1u+?>NxEM4gj_qm{#mdS0dkh{koH=fTL`}2T+aQ z!vpaV%6;umP`z5Gm3DxR2YO#q;!(Cyfy`aDj{N*-0>@TJcNFtb?s1gfm+(=mNldbe z!xxDw3Sr;*5`ldIr8Q#M3BzvN$v*kK`mo;NaqEh|`y|U9$OJreJnQV}1rNUY6)oNM3mIcy<#Q2mIC(V{q*V66ZICOs$M-|m`2^kHD6kq{i)m09;K_VZ|)Lk z#v?p-8yWNC{Ld{3UBpSCv{J!tD2XgMM4FX4eq1;H9;H?Q0MmW^3nhTYB`vhSz2xcqyTfg2#*Df zrwyVNxIv;*0KHKWE)UrdMUg@}AP!LknD_DH#}=A3h^FVTr%leyRpT8ML;Cu)?qa5!zrMW%3x>!30sIoCP%CwV95MO2 zb;!KwBEqxl)uj!)FrYvSICA93GekM08TOc`PsP~;x_f$d5t}QFFTTJl{lFcB7g;pi zj{z1t0Ut7%Q?Va)+Klo0okx!t@M|qa_T~Y*WY}Q%EiyWHHtNZfezY>u!y(1{!{D)Z z|9&`9cAy#&C%9EdMpXK%sWvreh&UTAUfz>@WoQ^7!fu1034sX2-CSH0cp7Rju$hEf zO6DuDg?{|>={D*l0^;8KEo{$l$Tp6Sjkc0HSqAqpDER@VRp>OqM&uxHs6B@d?*O%c zrH5=40D^t*@4t%EJsmsfvH-Mq>uxk*jHc)Zl zTiAH`@L^bH52E}8A$qTB$~}Qs4i5<^wG!X3#UK*o#z2I~ z1ku5H$+5Aq2S`#4v1YpV7Py`Y z?FMOKzTN}Gk{uJ2)Ao@hmn$wsy|n2!1O;%jM}RHC+k6m! z+!zrfc~GU-!vAJsX6B1?sYJoVuQg8JmO8eySo7fn4dye7pw9>G22KXX6~PWb>^=Z* z7txIhZN9JBCFdCccMBIRelK3UKry(30>&n8wGQaK_LuL>m(IyjK1}I1V2)Nf(vt8B z1TaR*krpuDdHxp%6gZR1m+!zqNJeKVh9!r=-2C^Qi5^%e63ZSV6Z1Z}j^J3}gq_z1 zF$()GJCrtThOnxH@jm}>V5aLafCBl6{A-_mNG-*a^9v3pPW;i%>|j`Hsy=L_` z-kzk&><-8RGY~Vl(Ou{b{B>E`+1Vq9D-g-V!3F#Ih=he&I^mZ-r$0Q9xQA3(yo{+L zhCn&$dSsm2KRIykCtr)95Hhc5f?rd6KHvSen}qt%MOzKcXIhMzwqCid}TKb-XgpWBqU zx&IdGXe$dkCAKp#crby|{(tBdvP>CziHWr}9j1h^o_~kdip2_Fi=u5jDkAoj5vEK|kM8GO29AT-Hy+%USYlvU z)9X@#KW>O@aN+SBKX#1L6ekzF|6DAWloXo}*BkIaQ~;Q1)+1W);+kIx2`NGaB^O1K z8XOg-R|-6PN)RClfyC?;4=5g>1azn{1=o;{-^2JJ%70gTEhSFMSc5k2CR?& z-!G=!L?!<}zW)F28)}b+CMLF6TmP;UT2Hx%B0%~FwAoE1yqX`Jh3J_A05-^?4t6RO z0JBY)_Z!2^%Z{E$gsj-XB23oE{2Y=4n{bK!@LBu=c)xIS-SMzqh>v6`?!Vod>G4m= zbXZnG1B*3xxss3ZE`F3uNEY*hgsTDdzE4a9VycF1Ppm4;wtwUsUqen??fVFR6Vasj z6L25h8#0vmM?e=~Aip8FJ-4I#-GS`!9Y7k{BJb`l4flc?rX2Wy1j|n^V9Vt1{5R6F z=UQ($2kJ6_(RSE-&J*hXs~dc62Y;mMsZ(#Eg^)M_s5p*#4)H+=O}C}cwx4}sllpaX zIP>$&NigfYL7yXi1Ky0jp`m?6rD-z)2>=uq!QlNXdMmtwQFwmkz*_1$M+?gU%+8O0tX%u}?3Ui5JD=z-vcyCHp`KoAuiJ~Hr00lP zzkq;s4_9};gUrkL>0Hc_TL6w#=wI+>%3~_Hcc#dFQS`;xr?jY9D}WQ=^!R@+_?eH}k3Jo2dGWxU945+|cSaN|Pl>lMMP^h*P!2CWzkB9(rNZ8Tr8yHxJ zwF9!HCnY7hZ0R=Mbx;~hhNxFEnCRJsZTK5UT;I?fU}Z;?gv1P&{4TG8Wb(?SYCYh? zF=OKjZ`CL_0y%IwREcp_|NhQ%1_TCbHObxhb_6wU14VJJ3oTm&J3>&rU}|Dy;@5oG z(2r`(B5Ai5>!@TEP3dX+=$E={=;$i%6-WDdai^qs09})Te40+SP*f+#kqsydNpxsE z1bLQ~$>7HQfXd%~{3!Bu`k@un5$BbKgZPE4X_l0JgbjfWkOX`Ie)!KE^;|iE;%0K` z(ru_Lm!pDLWMyrCD89b-S^ZJ$Cs#wO?&+&_EiDY-E{#rKv5VRQZy7dr$%5SE)Ra8* z5H*BW42Ga2YXEw38}Iq~rKDWKela%X#Zvop)GnEuC)0e`=;KYW1v>scI?lUsahzD} z^_6lQIPY{~nC5#+Y9Ss=Py+S+;UU9%S=)yOH>X(4TeP%!f>H31&m|2%XsQbdkM0JiEaaWo2bUK0$p=qS>lq>*!d0v)AJuK=GT- zPG*n}`fdqq?0g6C4GUw$oKXK$b-C+jd%K9WB(eRi{wbK`IF=gJ+i$)vLj76=NMPuq{Xppd;rK z=RvNp&+-OCU-F#8!~#9MEkK zU}KQ!lc0cr2xb+9MMZqnfdEzC@w70LfN^INu{WZ_zX-djsmZuQ;DR?!hSXo+abLs_ ze4HZYP3#8c?yOi)t3mHq6r}IW(7g>cDgj)F#0mEsw_o2R8)9X4Lj6A$&aZVIi_N7y zXOxVPX<>TLsg5Zqtm()yB!lxykftqK6Ytj@3u$wC;Sb6S{fMl1!Jy0lJXnbvYS%jq zTZ7!++3h_j+F6Dr@jzW77=xV-#{D4>+Umws2jCtee4k9)z0a?j>dIz^5vlj%GxwmfiL@b3juRb-DA?p`fb=Lo zGdXVi_~2a!A}EINdg`eOWAL#BbV31>Z0y&|&HnB+POyW(_yf0v&Avv7?jNaY2`4PD zN*_D)`T6mC2Gm+*rKK;SBM_>HoYpC>a6dE9`8X~v4sk5i*-xM~Of%9RON^L>F+Hs( z;}ftd_$&fp8A}y7H*{5%1H)!@S_O}l1qRqeDp6fv&s-0=S)U=J$Z;%!`$!i|=(FRW zH5k6R&W)Bz zcHk=@IC*+&v_ zP}o)GYS8Sd5{>&&b5uCDkmSl*M)|ZVHFEK2MH+r%jI(=#J;)n3Zut58Z)=}3-i|6m zW@1EB0@hagL@j>|h#3NG%WF%G@I@9jHc~+;-k=bq&H7uywxZ!sg$ZbvK#6GMCWt}= zcQXo7um{EXX`u5G8d(e2c~XS+U|o0>92j(5=AC;6Gr{(-g`$w&o0Hb51s^1}_iV5P zd4og~R;4zOeD-$GDVn|>(K!V$oybOpny4#0{QR^??e~xmFz(nWC@AtWb87+~k=%S>$!UEtA7NP(d7EO%^WI!j=D?FYxIWQGmeVtKS z{{gmnEN`pEgQ5Gqh>DJmKNY&~i0dpb@jvUWiN*F09?i?#+~a_wiuI6$cHpzCky2Aj zb63OowNx^5>fV=;x=Y?d09!|8Z!ahMD%rb7;7kIerL!jsq7PvA_$@#PldzvEgm1vx z=P;@8zZZA=w<>e^#wlRURj*MYZf+Q6KD%89T746ckiZxn~vm!cDB(LOfe|YoywymZBS@lnm6x?k@e7P zoBj=f4NO%L$C1OnmXQL;kKTF*P{#yr=!J=b-iHq#J{!?{4ZO|&eqF;=gfck;c#W{v zk$UpnfA8@Sx<|#EH3Q35V8{GW<4)oZgLcBN-#PF{b`swh?|QWsg|-?)l4p3K1vX8| z65k|T)RK2EYv-g%U~zAamzP(|5tHeua?xO=DBBeg{4=SGBxG<)52K@LaTlZ(nyk40(--FP0BH|IX>ux~!8Y zN)@%5-5zX+rsNns+(3_75d>1~|3bpzBZUn9 zF{C|c`p(b5dIM6i087UxwM$nNPngH$(T`0vLYh-n~oUN|?b`LCg@+gJOD!ES|Quwqk>Q9Xu0QLhJ{6Q2<~M z%AY{~K8E83IP`yCXo8+dh3W2lU?3ZGc=B2q{`b4;?jy{oyiJX*h!ug-`aZFi7c<5d z^;&Gf*pZ1CyNoL%5$#nX|7BW9q9Igo%12aL84Mx%NWq8d;SC0WJm2(lDXfY;~tq8eu{CSI3Nib;S`(+mpdm6eqtlkZ=fVf$ZmCN8ai zpf;3d%N8ZX9`WaJ^=q5$rSaj-#N?(BmjJ0{H;s?QzxuK0@#2K~aX?^7>sZS8npDXm zI+8;Ws$0m>q~b6)X(JK}C7hp%wV@^=rCT1hL5Asrwfv zi=0A+p?Aguevt>c^7pUSdYZxcBAl_Rrv=drA;sbMZ*@b}WH7t^znL}ut!O*OwcZ#! z5ROic;x+6Bz;ssM^vq#X)2k1L3O{LAwOM>{ zU8`T(Y!GeY<*Qd}R#ZEXL_jKi04NZhBjcCFiXgZBKu+;$yx0r@_c@J9x%=Wa0%s8` zxVQamM=eoai8zk6c0blbS9t%B$0q{w4O|K*a{@?*<3V49)2|xCZgd=Pl<@;bHOvuk zPf{q*Zba;dPOk&qIYYO3^E+%p9t3j;7gRWuj^ztDtNvT)#%G0O>I~j03@|JjRQA}5 zag!t*S|=E2;JQK>SRFvf6U`OCrk$3aEQ=FW1J>LiJp4VVNbvQL&9qz8QMYD`HG0-a z^s7$de1BjNhxGF$+Zm2y>+vQ45-`|peh(;q(abE#RzzdhhYv!2)E1tcNdFq}VUpzV z8uL4D;^Vv|j?o16np!Nry^HDCIZEERg%_X%0%;sWV~3dq^H&~1T}S!pKyEP8G(_xo# z;Qd+(#pWOe{imNC-hzu)8=pZA3>LaeLdSiUa_*2(Fjno!>It_9$dnhl?6429+*osTc$PI3v2LpCn!d5uzy4t>3^>a%z1;| zEU4Ir2B;G883)P&HnhGcCh5=`m*~fj$xE74x)p(jIEY4!o<^kkB<3m%)1Y*cab7T- zpS*c{$z+V!q>y`-ns?!w+H>$A0|xAB0R=od`PX!bw+ipvdV@bf)Qj-$?$jS+V+U9Z zU+3r7AXIot)-+ncQKr(*12#cKc_;~yyfhFXHgS1Wq*>+d`F$6)a0@o({UeW0o8^wa zI3qg&2}d_VR|PIgOp6koU%}=7{E$uGv4&{zO(0#Gu0=PAh6AMoVftTL)+6QW67G;i z${{7b`r%5_6t0$mjg1z69apJk$%(fuTttUQ^9{PpI!M~!SBBfp&ThqtEi9dpQ z3#_F{YdYl8YH z+rirOwA+ite$wK-Lfg>kl!2TGXMv%tT8!*I>gR6;7_G)kh*2=ddp7b77Q*DwpUFN$ zNAUn)3?*?xXl`ie^08~z1PG5pi|3>E8N(#?*4hoGKt}(oL##i3yy{y0!zZrCIiO}Q z{3`j)_}{YVWZ&B(j-!l($pZ`wxZKPIFT?)rF}!jEf4+^De$CHeq}V(`+-lw;{TIv; z`2G%y%6lLeGJ3G%wJaGd5)#P}yKPrjT)=VNxgt5;9ar(9g?i#czxPjB7tIpM+SaPIp*3R%T*@+2vf?+sw(1s1Bg(XSfP4@ z*D!P~v+pwiOAG(dD}q#;H;bfB`Zni^ReEjpX(ay$u; z7hH5!W{qr2(+YLhs~K$5LQ%szzjVWd=)x|`c1%b55xH?(?-MWzpnIX~9KVopd4{v4 zfN>N2?>Rf$P>~W)Mg1>E8ps2}y?XhQSWkZZhE5L*%neNlV?FSi&})r*A{V9W>2N`m zR^v8TOf2qkdY(U>j%gK|+buvO^5PIItA%{&$H>ja1ZjZcyBYO7(;*1T8tK{?Q)<$x zqX&hl!G=Pro!^=}h$*B8@k79 zUhnNpOvC}Xhlht6YzmUnUM$EGe=lN7r9=GlBWr7?{z|Io-VInzUPmF$Z^(qG`WDj+xUL9sA<2hsRRvQE>_+u1 zEKe5CcygF+m8x}wee6Az!kM_;)1fP_>dEb?Qr%Q?*scG1@UD1fr`Ds|%l=t+FdupD zkO3&pbZZ6+g3$2Yzlx{o6ZM_@zveLVqy;jt(t%`rN&*gr7(UVWBKlvoh^|)@&3XNL z-J}=QMoE|S=aBPMgv_sA-2}HTOs(m8t7I~Z+C*L}gU+tiG(Xj^OgSI$JNZ7^gnzN6 z%Gt9u2#QzOIXNA3gaJCFhGbCVIk~@y3tq6~1G&9)U$i5o3;=cHOJ2HZvGhwUDaGeV zPaor=z|sekC)5o2Rh)o1PrFuYb^fLdK!x0!RuP+%SMF#>~|l)3u+!+ zV{k(MTA6y!>EFSZ=5t4VSl6hYr7+a$xen+EY^of@9t4L>`&Rmh13FuYiwuE7=Iu!x z($kUQ{}*>}8qamxMUQ@w2&rU<218~INJ$|?q=_FP_v??~jxwGdGB5u}ygeeW%Ll^=f*-e=kPDc{Y= zbm!;pZYnkoz>ff(eMx3PinLf&njRHPe3%gaGS)+k58zhpI6hQSv3*Ke*DCSf_Rw&U zhzPLcVE~9^a0ro+9+59|!bpAGMco_Qb1mdv`njhi57UkB`_LM$Jsci*1Ta5!$~&OwBWU84rpBb8ma! z!M<(u0tgHzG#mVXL5&Acwgp9YI6;hso#8<-0)-bg?NqIw`(om+$#Q=YIV}+=xH&!E zd0!!0AA4QAfP?LDiJ`ZkkRo69qsLub0vWZ7wK>T6%w~-LfK~eq*U)=?!>E>GR}NrG5cfV z&Yn%1wMAM1y^*dR@pEyn1R6na3I691;<_R-o!#ePvIgTqRZ?t%Z=#ZdZn-$4>J58_ z80YE5P958hxx%%f;O!|6h4vSzjy+{2TN>wndD32~m~FNfXa`&3ZwI6CiK}PNaDC-} zbvODVmQ-@!#5g6tu<)1$4q9jScY}cUPReoBLZzTKK)g;yu^Z`?%9T~2A zAA!I;ho%p~Gxk$$)QId&Z1NVUG$O9)u{^a>(16&PI7Tcx(g))BZNZx-VH--U=en)yGohk^zh|IS(uGfu$;qV&ao~&`ceTUk zgm_W({#pbaPdRyf(BAa5q1!caI!w2?6$nmik5o}mwF1mLp^&K=Shshd zEoaSlwt8E2)Z&V-J6@RJdI6CiC&VT{q1ed`0UyA?F`yyL(0wL+ zA4*5TP{%Hm8;C#jULkIMnzDi1G%**EugmDM12& zC{g&IcEOl=Sd*x2l7TL{MkQq!J{pAzF5l5;WLNQ#EA2bS%#6^la~QbndC`O+kH6dy<@wvE~8gso1_&n6@4s z#=4$`hmLFRYiw*p!;uhPP{P^n`+hjPW^Tt z=B|E%l!!m4OnGsUid1L#_+!j@{cE9*x zU9}FYojtm_db8<%=K=q7d*3VXBf=eCJdhl&F*1V1W+{#ymMgb+R2v2~=pQ_Ikkp$< zWbnjW;0x%(y@GjjYtL*_dO#s&2VxLdW!>lIdn*~TEMR}dXKDx^pf!iBHxy>Z>}MP6 zX!N5^pl@ry2TygK{F=@$n-t`kJ}3seE9pVm3uexHT+n*xT8Npz0O z&ho2!LdY!S*>p-H2YUqCnS)vG-4&e{FV+SH4W48%@LA-&A`MI6fkuEJwO->GVS+%I zK!nW7Y$dAlp5M7I%Abx%*Dk(b9mNj$hVG2zTlURGni@UxZtsAZrcKwm8yae$ET|ca zekn2tEj$lgFG;6jMUOj5u6l%&qP9O7!N|YiucF}+8W$%W9~kYpP`oq}3}ky`FYpJK zn9GL`A8yMYLHkTcNs!2Cj*pZ~)_KDHm?e6gPxU*uhp+i@V0wZJVIP~_^=skBjAiNy>KFk+85j}x*MNpd&DVSywg zW9g9uYn#{lIMmkQCrI=R{bI&FYoJ)<Q%P+Vl;qAK*9lrN#6wya(spbl>|mBPMkASw*BKmm)^$w}Yeh&B~5wdT(_F_4yXAg!!&6G&~(Zk&3oLsm z{BT_SVUHK3kS66acCu9;{*Otm$qRx^G6%e{;^PC>{g%#v#}qOUTJh z`St5pvJzVZVF2JTTk(SEBw98GAs7?UG=eBp!#!-_Kku99FGF7J~`(gjc2q{7{4{t ztWi*~RAT=8JUIM8t&04#91CBiAOEr!yJ}@@2bdq+^ z&6`V*Yn{z?f&}pqxM(H6U%LspfQKhFNP>c;Tm==4FisxwXU)%lIck;-*R1o<-ccS3 z6R&Bwqa)On$HVg#wL@{hmTQ{;xydwv!eKQZfvE01veq`vIOf-_TkUd3F>X$luV- z*#^X&M42S|9{#=)E(EO9G>Ps#*;%aZ6L+IyQC2Gy-eapTii$jb+$4V^d5d1SG#l#- z-B5Sx3z-&V2xzyuDRB>kL`0ur2bgBEwZ$ zKpZN7OA`dlWECy>&G&h}DK^y!@Py=yYdETw0@NY}%Y^jubW*Dwocza6d@K|75G6#nh?Nwuxyg1fdu#(7TsVya>3=*UHS8ealL#eU` zdNrE=v1iX7YTZDQWu!nJ{iN^!lxULwaP%#HoQ?K+f)(I+EfC;0w{2`5yDgpb854aX z&HTrEM8x;c#lt5MzMd9>`&C416c#TYhYf~J@gS-9t2$4s4Q9_7J-3clWO?3+cv1y zteg~7aLq*X6XJ)|0jL%`G8NMKA|yXH&c6O$);M_*y#pwgK7tXC&~m~PfuMLh;U1Y^ z)~LbLnI1pT@`p%Tz9KDzWg_%`^SYXfNCc7Wc~MP62oM=ddvU{Cq(&retoQ&(kW?f< ziwPK~cG$XwmToBW*53MThy>TrykYDudiq;GeL6yp9>`uH3(^5zZw$j9Z?rC%=gfg% zjDK4f7=!YP3O&z~#i^RrJrukJFGB?uj0&gP7x<1J-jjCS+|1&AJG7PIvW%jRj(4a< zM9^GAGFoDh+uDMuvkOWfwa>{na3Iqg0nRm`qQ!Z|HXVRa8I~onkyIQa#~^G5xELhg zIluHJJnv?za@CW20uXLQYXLqv*@bLvhGY8v}&d(t*2wY4os@B1Cp>bB1YJNgF zNEj9|9N-{>GP$PO-%d*_YuW9b_rcaQ-pECJK_bgvg$Io$RHWyKhXHH0f6C~#zuP_41b|hECA)ufIJz> zh;a$lQ0zz2JRqbV!5b8Z7%3PNxN>Ov2^JNpN>DG9RG)<{hb1IOX5QgB8^1gV4V^?* zS13MD*XCjgURf{gto(T;s_YpZhEP}_V6p<11<;5QKdbe0AzDKinhY)k8za?-b&IZT zG7h%6ge#!SNML)Ne-1zb222Zqk02Zt^t3s;ppBBQfYy}qKu)k|DSc{s-GfYQ zYoWJXz)&kkfz`Z#UOG;J={tZ2343LD`6#&@{zrqy^&BZH!V!{P|i5~xNJk-E{bf)K<6LAC6prYUM{LjnY#>X
    #G((Uq4uRziqR3ws=<3rxWi9RFOj$;_N~JirR{S3qfF9>$`B;GV2b#R~<0 z=w2Y~*9GoD{V4b*fXB%=1Q+y35WGiU#xta0X)D{K^FOU8#t)v$|EE9VKR-o3_`v?p zE#NNx^OeE*f6_PV&=2zbToM)y(4>%*la- z1^sAAga)f}0J(Mqj8L%D!I2)PzD{f_WG4T;Oe1)f4NIJ;c=U*sS+H;+j*sNrv*?&c z0v0rp{raE#KN`V0nRD%Y^x%OGW9;ngj2^?(wtB9w##tCnX;0_6tMcz_Vy7!HcqG&{ z$1D11iI4CXr9X+te|}j;hFvHQpnt#qKW-aYQH-P`Q_?X4$Kl0+hZ+uhTU2#xJsY3Y ztQ{-b`|*gwIN&OZTWDEW1M&}P`INZzcd|cZG%~_ zJdr+<3KI&}y}&#G^)v7`h83E|>yA-DI13;H`p?1q*L{u13XhME zy2BxgLL!<>YQO)+HozDeFJuO#s1d2<8$_^V=TuZ!AVKS3`%EZ0@*G*4;^kD?KoJ`I zg{FFuwGlvC0`*X+5o=G%A5`O!pQ6ASpwg!(KuT!P;E1#jkT(*`RB(I4#k?s1dX)}* zwWD8-IzZO8L>)@Y5UjA+VStZwpalp4-42q7JOn^S5@_Ul?!>=qTXrQn76quyn1uo5 zZ{q9(aEFb7!gLly1Lfd4$?1YLk6?t?pxi(Cy}4-r;^N}s=>h`mNcPb$*@jC(7s6C5 z9q|hjffJ*me@Fmm6_8i!BgaBxa43>kqKAh*OpMD|?m+f|NY6%HKVbMt-b{Koq7b3` z6_=FMbs5KLP}bBm7u)0!Qh0iV`ObB$I9<*!-6cp3TwPsVbKJew5d8=h9d-XxfeZ~2 zRR$#JB}0ABIHV|KwMj~C8VZQsKw7~leZpZST-_r7Pq>_MDsA~>abp&6un$nK=An_1 zN)6Orvy>cklVUn>8ICeMFAqVN3QX_Zf)+4kdP1p zlwd!`MeHjyqB-XSp|GIog_ux9Q_~(G=~vo=cisgDc(PYS1&0m-<<}ka|JmHh$tU1J zI$(@aRs2j9I#j3)3pp#5rx zRWdbuI;a>f8YTt+2wU=P0BX1YjDzP*+#i#rMhD0M%O_|K?2o}uv=S-@%+OcB!a@lcBK!MZ+jAW6Vp(3!g@thoRuT*;h6}cdj!-T?&kFFZMPpET@pYmfE;2)ZFimFU&?GF|r1a*!bD#6TT3 z01wc$)O=Jf#C_w1>VSL~*v_)!6J&3VS&J4U|BxU5fYJ%wei+oqO%4KB#qh+)*c)VI zdbgt+CL|_?VCGhppu{A*K&MMk6Pt|qVXbx5r*fnJl6vRJHUZ99&P|0&afGDsTl){R zPhsq!ho(lEvrd61+t0qBNuIKCi$3&|1$=BTFjR!jVVnQkm2;hkdjt&lBTzBIIaL%T zJ$%o19X|U204NGHG7+eD+t|^KXg6$m7UvQ5QjjeH3R4K=U+N_Il0{}oRSz?^h0VYDf)n8usst}kV=okUT$ndU!> ziPERBF~{LK__4QY)4+|!w{IN~{==+h{rlGC`Chtz($V#mmX=;K<9y9pX3=7r`pD4a zBiD-@48#TuLw!vji-BND$CW@u!E`SiZf|djsAh8#2DT-P%Qo}7heI!-&jD?6sgo$+gcqpqudOF_)IYfSxWA45M%5g@pV5B*o7)@v`6qB{&Fsc(eae6(U=ZUunFBS1IEkqJP9Z+)=G z!=p0@H)J52=Vho*=4H0|CZ$@GgwH`je z1U!YYJs2v0b`naP0XEd!wA%z4ocdg1&c?)CiTMQxzRnXpw5#EGM3_7N&_^H~)Ef-o zMhOk()J>z`(G7RNlF`1E$YnT4DS>%N=<1_B=@h_MhTlJTb;00>6Z`oQPGt;+;{Y9t zC*^>IE27k44X&E3y#ByRCH}Jkj~+o~10)e`HJmi|3HYaxfeZiiy~WYHyp1zuuo1!yp&$QvB43mKHnmXvgbD|F#}{XzH+5=7^F9!e3v> z#~p)Twrn?PCnAIKz&={giR&YqK^6v#QRrDgzf7s*T%j*2fBzAQkjuSO!==J zhQ3~%KcwwHKlOj!H>x=R3L+rjFSwstp8f9AU9YF)j>j6Cwh)mHu^%r}@6e&*_5l*o z!v7TIHdLotFOE4g|z3FU1=OX-49I`S;3+5JeG!Yh8L^X2PTmC14|fd;`9(C_7Nd78Pe$FZe+jY z(?VGcz}ll}HIB+xt#0$;$Rpy0`tp&uD5QmlhSIdda7;2>w)+i(L%GTmC3G3ABYQoO z=Nz@XMV&vu8y_M-62MFoc^LBVqP-7cAkoRFchgKtAheeajzBB<<{}=@a7*P&kgT~i`h?n}v4p9Q9%K*;gSA@$9GsRKc|r`|nko!>f}0^uL+=8Zf6E$6E2~}k z*Kj&9%*bH>Vzq6WYa|$oVyFbj?JX>fon}a(@XUHLipARXI#(2o57a_r5DL<~Hsq;GkX8`cJ*W+N3Sa>sInHIlM{5|pZm_T?!9od|ys zG&->`u;I#5x#jrj=P#I)&eni>jg$ox_lt}>ipRc;thiRp} zLjN`J6orB}Bz#7Q!v=|jz&9&4syMq2Y&!usSk=_j2Iu6Thve@n!yg++3}f{9*-JR7 zfFELI$k7cM34*@{B(76bRaNn3TR^mvnZDmPq$ecD!tz6scd?_G2poXy9%#b&)3m$4 ze;&niK%s3o_KYs5LXkuDHwqb&OOUc+2yc?iHqZrHikggKJB)4`7JUI`vv+vFiO5S@ z&15t zoJ5I3uTcV~4g)xab|eC;gYYah`6YipJqV`eBKB7Agd#%+b_*U*zo-ZNp2c`T;tsWy z*$j1)*U{puNovIeK?<2%RM&$;Lt)=yxd~Nt{Ovxs0S^S}502+BAUfKpv{3JvTu8^8 zB>aQAUvM(e$R(7N7&Fd`CS+cM9&gWJ5Y|{2*hp`_^?m%;J9k6v&1eVEQz(71P=ri3 zaJ|V0R75q$ruic8foVI#5RO+Jhav+RCmVxVcgL~I6w)lVZf!y4XXYJ}C5No;Ppi=| z%8x{VJv0IrWv*9@yBP+e@`&C0De+%Vpa@Suo*8h{g}958iFg(I*?aMYI-p=Q7n$k) z$Xr^ve`8JusZn0uyYk!i(b*tjc6N5T;G+Q7NVE^uoR=U!^mYSoGQALp{EYhf>473w zg#?J4*C9Gsj*u7(<_s4s3@FU$tF;f*#;O7zj#Mn6n|zf z>$v+n;B{B`UTRdLiHE>P#Q*nxIZ(N>!|X{ESnZ!Q7SA2MDtaleLQo=+Qh}t_{v+Nb zqCmb(D(G*p$!h~bR{dh2@Xj6C6LIHb@fKManhvdY52qvPLork$>DLI1Wf%q+!4n^w zg$tj~*6GZDR~sC2rnh?G@c6C#`~_e4{d+Kh2}8NX?g;l*2r)%fFsD)2=!`Nfq=;Nb zTq?(B(qw2PM`gfS$?OCL<3RSC#)BY{f+lwPpN#{e&&U?8F%2*7xFlVZK~A5{S4^<1 z>PEe9ux?!lf?7tZquDSTaMyRJS#Y_#*mZe>4cJiv4aq1tvQjHa>k$S3)Uzh6=+h>!dOw~5iA&>O{XLII&qlK zV?dcmz4PrDwJMqTm+a|!_mBY^x&n0~X=s}~cgUdm2>QH-I6;*#xW~9<$NM{biCw2o zZ7jSF{Zls^;KRn0{#6L2x=@P=%2kFTamdPA&pg1fJPxoS)r z&6L5p#m+!Ot)26mRH`UAnNpOXK$}y~lMohp23Rn4dH-kd6lOD+rn0jGoDmHWKriBd zmo6iZ5VU`=Tfv|gr8u>CH~=zA?k2eAw%c3;=5LVPwV@JYx1nH$b_qd%J6ZsoC|W|% zL@HZ?rK$M{sU=P8b?rpoL=c(_#NfQLsc6$eG+n_-g!*o&#=}Zl4Q&a+)iK0h3Cc?# zUBc5(^|^%TLN>EhrPHu>jVk&qWdG3=j&B|snm;_h~INJU(eSOPm`KJY}xVqzvLoBj3m zO>_1zfoZjl^?p)QLM64tel({Jy0XkL%1a#le(-06jj?uKFndL?;g)C#J-MWF&#j_Q zKi6SRV%Iqsulap!hUaN}>8X6It;0DTmR`Ra{VoqmitqTOG1|0-WCfE$T1y_uovE6R zSA+=(Qq$ zzVRQ-JekAP$xC&A$mLMJI<~cveB>Sc`03LM$TxPaOjywyxws@~V|w{H+W?DqN&l5| z`Mfs>8^kO-@%8ow$$6LtqrNzeXZ9hgG*Op zhVCNX=+I$pNxlZ31lRstWBO>J_OR~G3|rAMUGWA^ z=|472e?jG+gUT+e&Y$?W3Llp>`O|C**@Sz%XOCMi_gBAmt2}9if%Q97hAn5+ZVPc6o8f%6k3~MN)a2x6&d$Y`b4)%2=e|yEi7zP2 zyXv)FcHXYZ7T%m!PK1bUmJ9XjzL-{eXleNl-c#}O=h*adSiG<^`0Ts7J?n^pc$kze z_ls1OuKhBJ9(BB1l?UVnX4aJlPtp1Jtw|IrMA-MX{OJ3*tXD!+|M5G0t!7Qu>ek*< zOMfY7rR2WaWwk&sdwclKgIaaB^$A1c_Z z&-@{f#`kHYEf+U;?@3jW69kB07c>V>_-A`nd_Oz*0rf!`+2rM`Ek0!cm#7C49 zVU4?wXg~#23QPy0@2Q~lQLtcGumCuI1+ORNSwp+F%%wiuQwE_xu0u#sg^*3*G{(_)^OG2-)apq|E z`@2iir$ZRKGztcp>-!X+I2mq7!=Ia zi9L}1+j9P&>RFnRhUzW_u4o3;Q8{#wU*BwxYD};N8T;d)24*YMfTp(z zLdSGzpfTBlt9gC0L%oHBA<|(fzKM?{5XPyeGgW0Hcd&a3eV-l9xP9sa6|DZCE)m_F>|W zxnt4(-00o;=|d}!j)b7;=eed+HUA+P>>T)bLbYey>hH5HdDKAu@Qcom;-?K6<`e5e zuE6&PUcadFQwhsAPV}%|J~%G*#e3x@@xhJ59WOt(?`sZ?R-%N1*&2Hi z&{$Z7gElxJ+9c1fS}alj_HZT>1csv7i>IPvxy<_W9<^RgksxPa24GfW4V)IjbsQAh z1y9W1A~hsi2X%Cq7-cs!Xa@S)1PnXbOI^OVIVBYs+n-9gdi*vkzmz%}O!qTV6SpFK zVakP}?X&=ZF7%0e{{1)#g(UQ)tUrx$M_hW!FoHQOch3X{cW-{Xt+xB?=2fi)-PH@b zUYSp@wLJ>0c|3Z6ePCzK;T?_h&b!M@641C%rt!(&K!Q!-{Nb+F&FM2kVXidPbF)?v z90YQd<`*UQse>^$2g$;r_vDhe^+SM~$p0_HM4pVy@B(vR2^jon0mwR+=ns_j^A|6! z{rnz)CfO9wH4B^#fN8ff2*i_k@-3wIU|P;>(tLok5v^G#(13RPQD0I%az7RZzI&m| zQa4*<=3Tlp?UQr~@a-}nAZjM-aMBj@lVk`pn|+@z9eV;ix+uFYYCHpsz4XFn7vf<;8#-BV2H{x zD)c)q`E7KzMXkcvp>N1wox((STz%@A`#`W%cE2%_{c;zSy@=c>+fKu z1!$K*OB_s(88|9YeGmwdFex-FEOuoHF!x!oTE<&s$D4wz##HS{S(y&CpraQXD8^V* zbG1g^;P=ho&}RSs1TbF|h;FBe?HcyLpuWl8!czq|<1<^h`;C=f2Er$X{_$tOtZEK=31QOiBdW3Lg0I#`L0ewHk%M{prvT zZznm~o9p5TA8Y_gnLpQw&`(7h+6@nzSf<~f8dZ#ZzZLj85qV*gl5ldMECtmUjxEaN zG_z)lwd;=tbK8pS+!8M4yr%yOXj^uvj5`QMvV-7ZXR|T%?`iEa<1hE!W1ty3PT5dd8@G z;5Np&@sH=cj}v=RRrLTUkkpTxl>k}soGNI`3JFzmDE_lIX8L{*^ceo+dvs33Jbu-O z+q8aN>9?smI}&_M^Wc}jpAC5DSd}^kysV{= zpD01eg!Y05eO#5n!K}nKtx2i-WbJ~iKm#32*{aT2J`m1Or_ryzrbsA+m{JIRNKJHn zw(2#*s(6ZYU=Nsx>Sn=nvvUlboim#1@1Ug5_PC1I351CTClZS(zH?WE_MVoZKIi2nD@Blu zTz!^lD=e09X1|Mdy{40-dQ<>==VhBCBdb=oeywJfKRX~Y&q_1&VWnwT`YI3JvxoM? z_8;!ju1MXrv>iw(IoX)V3xL0J_6h(zj)_zY0q}=I?Pz&%6bgWEl=Gpur3rdXlj>DJ z911~pJeSzE692)SYnL|N1ItZgK&VAj^Y_=)_Ihv7D4C^FD`b1<@RKL&8NJcMUMNz4 zh~#*@$Etvs$N7cLtLkgm8%2IinGLSL@j>KfV>$|K>1%7ar~@ASc?@`>=_lBU_)b76 zE`khKXxm|SOzfR?Vj3EA-AjhkHuwbu)IBQ2p;Rzh@CBPS^^7snPbNs_@}c^^`|xFW zG-d60){WaEm|w@oU~8xb4h8d+vmax|Ts3 zOVFD}mXq$hx8Iq6|9;l6R$;*Vh^J%t>Aad>rKh+3+G)aw_0MB(=-TL0oB7Clri~Ss z_w#!_HOl!i-r&45bs(61+SnSD*UG}Au>;;|zT7^=~NS2?(w8+$W(KQnv z@JFHp8=uA|(Q7cLF023^ldqZV*tOPU|$6sldi7WW^Z{UYkr_|X27Y70Ani^>H z9kWe(<1u&e>Z9T*YiCqS$eC-f}R=8e=U=>aQ$?%IxF{kPWxz{9J!2? zBbV9LU}HTm_p1NvhO+6L z@4aAyVwlchEM~q{<3scDE+(baZB1b_w(-qG_%53*HiE_h4U$cqP?p>sY?)X%aQv5%%V zu;ZcCgA1eK$uQIBg5es-z(+7FB-w#F9BtvqP>NU>HeQiBeeRq&J!opCK%^Y@RozW%ZwLIsg&Odv^)Gt{}Ltq^F--*}QvMxwDFn#Hms{ zMfK^&qCYz*noP_TTa#jZem@X!qQvrmx68&FlHz#zPUScx?vWL4b1wV;5~ksBw8#R23Z5oehn(EBkYl2hEM={Nk6nOBZ@lZd^4wF3z{s_G8<(9XE$B zbqAtlnSG%zy545;!Ng0sTrIhq^Dn$oio$?weUXn38;13Blyg z>j#2N}HGjopU5zk|@F8IM3St0z!1s#rPf3K~y<*$cHMYyp(#DiR6We=fvD-9W-V%2u z(=}L|3`5G2Cp&IF9fpvnsXzMa;|{%!na9_B+;+kpnro(o3SGe$W6Bsd$wGTvH^6OIrw<&;{ zxE$z1sDI6zxpP^Vu&^+nN`d}P)4FjxBqIlq=1@M!ltOnFHpbl#8GUfmueMkk($H~I zHU8=2$B!B2vEz*u1#cfU+ngFSl#jK1>Ia0oXe{C~15C!q__j>QV-`;u51A(KjOs0l7N*!ZB+r;^wh z{#dbluXFc)JCL{aP_387aQhh&m>|RY6TL&GX$BccvN9Ao0f`sfb+hYWH&Of0K}kf#8TjIc4FYe!G-Puy zF!{@B+FQ{pkdp;OWiH&tc))C;mLkvJ2P^xFx6Dm?vAPBLdnu9^&`V(u1AO>YymbMc zB1uNF6YBFGjnwcPhP*V$&{Y1EdJqriW1wteA^qKuo8O^{S znVeF06cO>&W1EN$Sfx;LNPP}CJ;b4Cg**eDAgph1yGIdIMZFY;sfD?FP1gWKp#hMA6BNml6WdGvz3<1Grz#xsUiwTQjNF_Vx-otAzGhEc5TG)ySVQ{6Y{`vS zr6f#)ICnB~iJKy-;F+#soFc)L$SU`IB0CZAD+I`*KZzVonFzCaK4``uZA$x@O<8-E z+V4G9cf&AR5am2|O)=;j4DJ|$jo~7-k?Z3Z^0zLamwe6k{z;^rQ&OY8B>yiP^kexy zTu^eM&i78u_Iq~|L28J!P_kXJG#USHY1H`sX%mjqRMMXz1TR~A8$EnTY0J|lfV(K+ z^=(}sNI`9xLh<k;VGkz&8303D2=PfsowF{QBCwQ5e148 zp3IlOKt-Nr1w#ioDg|QgWu0hfL0=CqI%hU)9oErxtObHk*%5i+LShAy&}vV(NiW%g zkaz=byjm#QTvz_x{FuyfdVh6p*_Xcvn}S&zrVQr15T80t>(_U$#1995_sN_#Z1L?; zIJ)iS;~v$ikCvF{Y^?EH7Up0#`|*Wvot_TXQX*O7o;|P6Tl>@@CRjPwp_K zjW?xrmKL6Qnz~@e5>Q3w7o}?7|H9AJ=^ienV7q+rE>V`-9TNg4d;RbAC#<~pg8zK( zj`w3@Gfj;zZk*l>dqgmrH(VAsVIAghb#`Nl0re_F#$yLn_2a6l!w8;GzZ}Qy(?D3@ zD$73N^dsXh(h3mkiINx)8V#_dNr#{XZu#9qQA5f=$bRIPq+!*!AU8s%SGHn+#KHVM z@S?x9I>Wf6q(oL$7IfO*=SxO}WRfPBuTty8uz>x)HJ(cKPiM)C6iJIw$!<@Se97_b z=H>7OD_-t6{&+2W-WKLcD~r*I9?qw2+aCMv-y8Ot?^wljwUt`$7fiabL)_>@_mQhg zwq@RZZ#A0O>c&ifPa+#AMfE4&6bPjs)VLQm_~HHu!zhZ-mN?b2YGOFb$Sway^*UZb zZcffS+2@aDc+PuukK>=SVBX}2l7$cuzB!<^vpD^3yT27H;~hUSuSK?zr+Xj>E7+Aih%dQRxK%Z~A1F(%D9#Cr)W56=OWl2>{dBt`MuFQ&(l6!WSTC8 zp5Ev%!#coM&w$tR$&RjZO`aDWR?&mICxd5Yf?OT-&+~0D9{Tu5Iaj8?e%2x{+Zm@r zE~FNn8hqBjY^|Q&2f^vbDxS|Z2$I^w#u)KBFO=kX+Q|NV{O+_*R}US%w@qxyw8>NBd4RS`WMwS!FGK@W-r;w%V;d`=0kkaO97Gyze`Wdy6=nlN8`J>!GXDo^W^% zz!E(hvSGfc^}<^uvf<7CIZWE@Y^PjYKog_YfnkF*KAJ+JS}hY%Iz=}e1a<}|Lr zw~qS-l8U!T&XN9k*uWxY&Vq#70>h$suFSxU~9e+q27gZ6~=6#C`VPSYlo` z9y0P2TEmgM7C)7wK3rengC4Z3x{JF%8TPB_$<9A0)39Y*i0ns=+)zjZ-(-9DB>`mEw}#Z9XlgGDmY-WFyb1tvn4cn#6(L zck9R>7jiEkiJG8EWE$z5u8!bj-{IU%D<`Wg?Rtej3D(nOU7^%*KtkHeH%w++_B(e2p~6 zms#blx?Q&}2YK3kdCU99b)YM0zKB4YwvO)>MzPZ+dhnG_mqVN3r_cNSezyEN6I5qX zKe*;r!&0fA{)gY3zL8b9FWg=|aO#aTaU)(AtJblY9ZK=`axcrhyQO~=wHYrDjQ{!i zs>vKZ?YEkbWj9yn%>MEIS8(5AY0k`Q2vMVBW5bYfTf+_(0FIG->;hFo-D&JM8+mff z6n@WRiwMYVxV1#sWouyT95=yFY6Rjc798|RL ze{|dvuDw+2_OR$*6n7D^@$2#_CEPzGsAkFbBYvF}Ypmzv zqGgkru4Ps-KOZ(Px_*A)^(&oa9$o4@;m_C@-9@ekE+%Yysh<;?e^~VRH@WE?zB;e2 z%v%svm>g;=_voXL{m~-EB0Cv|wL8Q;62|!l(igL>Upne>t9!G^lQ`e9fMXFhbkpi99-Q-sVhxEPb%N znGNk?OmV)`!ZRWJ9*6bt{ID!xI%US<6{d*V39?PNWo8fE`_nBXA3wZCP*U55iQo%* zE^PPgb!m(z_sAX@yW(grm-C{GmWGq$l~adF2jW}1BYGJZ|0lWM`*y{R4qNa@w-*nw zy&pfocH`_+vD0f?2b#Da$EoNWw8$0eYhT%1I{mW4O_5En7R`BfQ|7_+{VZt@P0qXR zU=!r#*SKK$Vco_$*Om5CZGDrpvTq-g8NAQkdh7F^qfK*I7~YS5I(0XNy;p85J2anL zf`h$%TfE87brwpI>=!DELN$c-c)9N628tq?W8x;&Ki10byYRusB#=i}zz)ynf?;Yf z{(@u**u{c|>s1Ofj5o|YraSFPb0+&i4YPz3xeyv%kaF7TFc`b8%O*FV;$r5;|5t@j zUelSjWxi2qMw-o2ZV3Lw*&D2v9X@Qfh^KO~&UgI<$M>r@S_z*%GVXDHXqMw#39)Yr zEn8imJr4QyM8J7_TcVYhQP-Ee!6gR4>IzHUp~NWL)gNQ4u=J6))Lh|A_){vI z%=Mlf^`JGZH-&Vqx6ef-!gv;W$;|gFyRA4!U8YAk&b!SiuG$%qkj$wdBlB}xH5xWInrhPXqZn$iR)cyV^c+KgN8QkkeF3?5gjfjN_EnW-M#3wPZW0YxKjNJuhnA z+Pih{uxC#-28x)_@+ne7#ydT;{ko%2na6;HO3k3GDgER~@lNBi^|$MmW$s+=$ufL7H~6aQEXk>d#D94NKh|bh=O%vqSmyIR%UPec zZDoQUzYxZXx2v0-^e!Oz$=3Cqx$|#+P>~%e2zAif^*DL8*nmf{!J@EKEshrr-m_*d zT77?z=f}H^rJFOdd(Pz_nmm&}(ZRqCrE)2_iwht-x$laLXRd|(q) zcd-9Nu-a@f&sz)cFWCGzalF>(=tI3@Yno1;G3)#ev6$-QW9K5j#wasgK9{oUHU1O! zm(~fNTD2Y8g*OxL<}wTJitq>AW(oVSC*x%dn8mB|kuf_;PK6!cxaI2G_H81o@q(3x zC$B&Dy7y$JM_Ah)O0!QNP5Q8Z@>%N*!xx_lB5Wo`9?O3DzG9%+I;dDx$i(HzMekkH zR^R^hRpY6)$;`hV%O{ZqPkTDMLLW?(sJmw)d*Q%=a`RqJ)l{y#2j{lEaCaDv+cMPA zZanMhFV&7d+f{RFm+d+BYh3&5=1gN}m6jdan!Kqt74JWW?`;{0HLUdfR-rn*tcFJbM6)qzpRR8W;xPMi=t$z5U< z6>Je`x_o2N?x-f8jyG;Pp3@nw!2Ez0*M@x~oA!Q|n;-UCFl9=Z(~m!5)#SoW&4WI^ zbsJ)>j%}T%cRzWIqrfRfNGf^YkLuHeQcJWlk}YHpYwMPNTIW~6 z$}eJAogkmOd4Y%`UzEP3`*=>YVQZR+czrpyxu;iFj9BsP%ig^ww5Lk!dCQiNFr($= zta6cMk8bbf`1;o8i;Rf*G(WDBo?%b#9C5qk-1X_$w+)sp-BSZS`}MPo3xDeO+?vfR zQz|Pi!Zzf#r7iX7ur{ql^y;GL3jVV8TQ+CKLq)BFXO>)<3 zE(akFkcrM)2UI*hs7yUQqItcGNAZ%Oc5ALx(*5mHiQ6P*AXk>Vdv5z$#W@?V9yo8w z((uNuEbrP&8}UD9>W{9f%Z0Osw+NZnSgU>9XZU4OS~{MClt>(=TNrcK{-h)pryetzrwiGt5T_xWmZCOT#- zN#7JM%C#vQ-*wLZ##Y6|faR8>f_roqoiX@|&v_G*aMW63U2*-G)GwdIj~lg@FPn2; zw)kwp6b0X=8!V0sXLDBaxeFN_xfPPU$;)w-DhD>Y{JwkjUJ(t3*$tsM`3I0K^uOdB zw!ONhbx^6}S!&k@zIr3k);lff8)750)o#T{UzrqCb~kpmJ|51#kgDB}PS#W%z2`aO z?uxnlB?2BB=s65!G|YM_#%FA!DKJe;A~~E3LO1;E^Vs4!$F1)3t*cE4t=hb~?dxFZ z$2S^jj_yG$Kg+`=*>{%OsW)9wV@mCm?Yk~Wy>U)I-g$4R+3`96^ch^kcO||ih|k^mgjQe#x1P{JfqtuN7(Xofa5z&Vw+$(cdOj7A7u)8 zwQV1f$n_=TH)?_gVuAw9rbC9*vrTtMIt+e;i%RqW`d-l;N5Q(dz*jh0zkLYgEVU~u zE+(Aw-ua`8K=V&k4?6&P?mn#I8&PnfkY?=Bk?y!Nw5O<56~_UNVygqR;fOV2Ay{&@ zSFsD1yx2;g*XOM8gV)CN##JaRsqLbo&av)<*5>MDn%XEMB0B6;EftUqOlwGhBsI>c z$ZU1eq2T_;n^R1y017Xxu(sSsTs*|To1JmE&ks7c-!?6FIc&Zk>lmBBc~x=GI3 zB1eE!SYq_2-#$`GR=9r>Y4OAV!)B`}7GO;vkk8S4%TImLf$eDz)pN7Av%$ICyDtqm zzvICf4iRtKM;7K)W!YgMkP;{nIUbV39u$-o{~@f&&vxs{trqFao76RfCNBqJytFbG zXaS$f7dFmM${w~?um%I>&J|@p<9#s8?G59MZ?MYu)WBbVim(xa8u1(beWqd^wU|R! zr3!UIr_dWI;;S*ePH2Z)i{GA=dd`(2z2H1%O zKaV;o_q4Q2_K*foO=}+-EcW8nZ$~ zk`(L7uzGmEuy?t!5G)O^p)UI=3~*^d{qu0g7+$3ErArP< z)YM*<72S-WT%u{0FKdsYQq3~HbA$x^ux#IIOJ0`L4jY&Cll?W~M51kWx9593`ugQk zlTk1Pk8Z~cfe(9BD{nC~+L4WWzG$zdW!BMRZYZL;O6N*oE>$w#-9Xj)3c{_;?6pFVvMW1a1>4tCH>sS} zAP{sX!g!HwZ{V{3eLig`h5~3e;MX^VKNV6T7?~2CD=I&-h&t4}JD^f)6P4yyL1X)6 zuGzLqd)Qm?E@dVhEnl8)`5wDT02MfT6*gFv4FNMEP=Ni3uIcxA)B7$(U&WP{M8x^t zdW9HpvF0>V$rhCEVSfob^j1uwizz|^qzY9GYS_ z<0N*yQE(zlhXxM)LeC^gw@|>T1Zd{TBv+#Ibm+}x+M4iYx-v3N({YiD4EI4ppW3Gd zs-D;EFzn6|?L3gcTUkI%I6AQXA0t%#Xt&(Fd*4+Iw|u#{JYFUgG^ew<@W%EJ-2pwA z4dUOU@xCPv|7_F>2{1?}pt;wJ+GCQ}zWZ9E=TWSPi`Z0VfM?YA!3zU|VKA@$!kr!o zhSIpaFX(ySq36XrT1l{}nov5)&vwkGDhif9Yr;TAE%LG#GuyQ4Q4eEy(+@rCTH_<4 zQRzzrlLIPuA3_#1k9HVxyt9eMQAb^jLgX|~t6q;j39;W~Vw82O*e<^EL<1*^89S-b z7FuK+3#9#OcN0nzfD5XK(=hb@nkJH?cUE;{egKTpfEu5+)T?3D&I}xY1KXrE32wbV zR#MDGg{qq6V)iZJO@jnyy|R3XqM7MH`6V2cRrGay-4l6hg8Syj96HT^moV9e!fzde zN4KQu9Q6{kjnil{kFXq|&iVbt7T2HOIJrxbxG%K9H>UyoJ> zYhzI_*pZ$Y&L89=5wn`$cmc7?~YT)OKKwU`$NJ??4;2du}kjjV$84n`Gy=6AQ zpq;fLc0frkta~VPZVmVMI@V<|_Aepf#Po}mS(jp8w=f+qTHK>wS(%(1uhI;vs1{m; zvwmBvdO6SHeZb8iUDG|+DiwksFv~_~`h8Jx$k$qn8rNLqK*FQYQ$O2oV_78Ss|_0K zQtFE$&u@niDgk7O*=r>JR)IPD#TEE$I-TIOWdD%eplM1sr+}YY7-vIwrPolS(uJ{@ z4%3G@KEigp0e{LE^bk4fu$=)mK%QOct+gIMVtsMNYfhGN0GCj*wU{Ae6Y%Q(hl1gx z_0)2yfu+uVvBn=moV$%u?yHok+&1N=@IjdQAjyd*jCiViRZqixI9a`jgJ7T5_ly*^D?#Hhosh*qAP2>an| zycO1UMFc?|3C=GV-LVRgfM#VW#_Q?;`!03U6kwVI4!j=Qu>Eo=4*L0k5-Se8yQp_( zW`sOo{YE!Gu$4I9f+I7Vc)qaH4|D7+NRBn^eha}uOyCh%(EFSRQnUdB%9j)vM3 zHv3Y#2_aPXOvEeKVlfg+@T|L9MI_XGg7A6hbQcM4m;!y*F?WH<*F<5mUB+bd6*AT3 zjOI68JD00$Ek|vb^C7F5u@t?<`qTO&E(Jo0g*(d6nwwveZDg8xN_ibVUVRt;YZH#Y z`6uD=zPgNR^`|EsM$IVmpUuaI7hhEcsjh9aHuH8HR(za%mk~Xjo!M^cqsoX_6PDM? zF3p89%=a~G)|dEGCSRB9)rS#&K;S0kzl#K!6tG`yCJ$5qFbC=j8z~{%9OhQhLm(`; zICGBkHa4l~uY})KapSD4gcXp;O3}TCfe0dg?@DwtJnC53)5>uF_3^#-L1@&S%j7-t z=b~QIFoc51NynvY{5i&csaB&nafa8zhudZ{jMA2b*pX;s^F8RpocT zh(%Fe><4Qtjo2esFJd-}mrIMKrHW6D28KGKs*XC3S}Wg&XcR#zw#}&L5*3;!dat+J z@21&S+s{~%8j*5}m@j+X4x$FR7W*adPhYm54xS~uM?fID=+D+#Pnzyj^h8iM?IoAz zl#^n+-swl2(wA#D`g1-%R-P+*kWA`>5~W(^Hb$)aoG6by7OqLH&`YOGZ28>4F zcET66f*dD%q*fdJe>ivz@KG(cw%v#6V5|%@0XOes8y6djEp&@7GjxL_=K5k%j(kI%;#R`1T!-Hq@wH<67AGE$@z}h) zm>Z}j@$l-()V9li?*keCIXP<|nPHgBu1AM_K*~p3f)f_P^Yt;G-6AHSV6z#{oe38B z3Gct+e|D)=t%9c1MHCd;_!#ZRrWEDaSpzorv^-EF*UC0BC-Mn5PKn_Qb>Hq@dH(dg zqHY!bO8ArE#r-k=+SN-eKz@iLm1=fNv-6`O&N0!^co(dlF?6CyFYN(<+7`L!C30Jz}pYS z@66~KvaRs6W)bORCKHddu6zp|9n2W;j~Dbc64hwE`T>ClAVd({bUCHSj&K2xD5)V+8K$pK08Naxy=%;+Af^6yxixEL|1=FYlTf3qzooe;~EC zjb88=h=bL|SmJFBD@pC3G>RA5C;4Iw5SKJSToQFGXrt%gYu6`Um>Wm}Fb=PBCfZKEq!oA#3}_&E#5iR_`In>4_NO!vapQ|3O5d=EMG3 zl!7gh5`A<1l|!+MfZ&sReqUJsG*R_@A}`id9VZ4Mg#&*%&FSqQWD+xtG$0=fKBoyd zWoClaX}>Nie8SQnL({v5SJ@(3pJ+f$ssH?4d}0zxrCf+;wyEVF2?V01@KX1wBklQJ zT2M9Dxp#RPKS73I*ns}Dgu@El2UmZnN>(Ln!nqz_+ZokthG1M?BadxuzG9*zAN*4u zL0-zA(Tg`&jw<+i{XV|NHXv=YHJDD7<}~>kLf87{Ks>xhJJ5=5E~ua;hy`BKW44&= zEqn~m0yzt@J05gdy9?90zN`06*xvjdiJu|?O!;zpM{=I=?A&N}t*w6^N98~aC0aYK znJJnQ^EhNSsxIn0f&JV}4Cvz0`)&TEr)LrSe1AU1t+J~xWE1^eB;x9Y-!5H0H2zVF zbC81H8wBD@Wr(+3FGC=u69E(c>|LQ6H<>zXxo}M!Rwu?BpA^q=ecZ>iZ$%yPGFn^u zn!tBV;HY2TXG%@s_QfXBHaED?eCHA%ZuB(K?Rw^z$ik6rrC5EAA^womwrp*rl7A5R zh4=A!dZ6jz<7Hi-TiCNKz(V?GL#6I!OxDKgTaqF*VUshBg4WKBf&s~iU6hXp>u7jx z+%F_8o~I#xcP?iUcir5%JJcjAF;RyAY~H0Z6>HYaMCzOcR;m)emqTvHbUG4!9*v5% zphva|p`#=x+zs=^^I!%IkSx+P-Ej@Dk=D@z^ zZ;KW6H)zhz9bxLLA-;nyZ}i??+A!4?Xd2vCRIl0I+(h}h|IUi1gb#Q|ZUlBx0?D-F z+OvWCZ#erIm#=SG)PN1___%{l1jz)Tz+zVlX?L33lUZIna&NI<+Wtok4#VMY>-WrK zR{XBy8)>i+hEcUF32)xwjmmf4Qlc?a>>O$4yM5);h?93pj-aRA83k$K-Aq$`rA-tYP;mBw~ ze6dOiiX<)xIV{$H-t-eOZ`7f&jG}rPuob{&@clM8O|kpzZ7`^eAxj!61c;^AWGLHK zJZ{R~P1#yHq?=rmPP``sdY6}Q5HG~(wrR&m>kpEs&HskNshP@l$nxw4}z~M^_ETWzcPvDhgSfXVVUZ^7l zSLU7UO?L|9AYuOj-_p`NO8-re#+nlV9#AM+fz_oe?D+J%vjD{8p=OAv_7#eKQl*$l zU6$H!uN{Hq`kQZ_5iJ_k)`i)rNaS(0Q3orSlse^A#24is_DW+_7aE;Zoq)>}5$9)$ z5BaO;H;tUbBXBtTbT{oUSkJ#DNMm6pNFq-6JOpm*A)elhcDk+;K*WhR7!q$-KQHQi zZ5;L3hpLF!O@wq)WcNJSE0n4!8ukF_3&DW*`twqu$E@)ajNPz$(`F)^Ub2DEk=KNO zL3yG;@^$(I@UG-8-4GEyD$|Wgl0lG~C_s1`?Jj}4hG3OTxlB3{9Q8NPo}Wl@j}R1FqlVo2AeeFuKjh42}DtwBF!Vut-DO@v*hcB-3SKW9X&4J`3w z>^5SJ7@PAO-_2y_&s+K@o<^0L&CCa90JArzT#dffU@an6Y6Y5jxUX(v{=quiw4J=C5RJ)Xu9^aiRB{C z_64XDDc#h&65qW_B25grZudMt`SH@*{*gc-r5kDh_|l>Tvg{YlXr2dbWrup8!lZtU z6pKtfEuN!OoOjNUr5rW0XURrxH`??Py2I*a?fnt- zM2{!BHw?l)V|Pv829l=N>aBq4%S>r?i_2W#ofp2t?>XJ=pOMzV`D$^^aUZ092%>+B z_+$UjU@@I|Y@#v}#3_m}QzQF(f23g675yWeXh>IgAEgZ;CO8MP;_*d1QheYV$r!R< zRFHiT9=U6x3dK{&j6({@`IaM?J7s04MCxpGkAqo`9#9Rvf8juqm>RU9Fdb-#bHdfa z`LieE;*SS6u_uCy>(S#xc^5WDEVTx1YTR_JqEBzS+Vy60<-~WvDE25vP>V-G}N}7f~ z-|4)wUt00-r>lPzGm-QnzP@Wf4knh)G8Y$t(hY^qfFqMO(?;Z_XIN=6M~c^0`88g2 zc~Ca8rOuA`#OHP?U!zN_Fl50((%Mfnwjkf^&59M2nxKzv&?}R5<>L};Gjw9eMfqf0 zCwa2_qp2~B=}7SS{=z${cWxvj*s5bpr!03gVmRaZ&;jn>D1|v?k@6qkl}Hw=!w9yK z>TDg{PlE(_z1`5iP{j7pq~@d> zQ$mqh;*h!}$C82(CRFGEyU4aRCE~Y+>JHnM80@1@wbczDewjVXF!H}qzV#fA6$(IC zmiO!ydQ@=Uy5K%ac09=nsQ$&XyUE~!)6eHU-^;NAug|DDosLN462MFzXjcoM%%1sD zX$+W-j)osupgtWs*DQqp()^mc{`QnhIx1&m7ntd&Hm4c_I%rRsRIqn=$sWyg-|su7 z%uyW2#!w!)!*!d1r)1Ll0i|m|alA|bc~xwHcl!CM6((f-!>1vu4R~N`eAOemrhGSV z&!sU#i@YMPntJ~2siHgDOL4T;xWK+zEGm?9L`Z*55ETFJ~NYwEPgWAVGl zmprddp{6?BJsiuYsg+k|)V#(@B}e>_rRn#wN4xSj-ucik{)n(@Tg2vIk zqgetGfX{k2h*>ob5Tm#D!F?at#eJ?vyc?0co`Q7da^vM2xuE8f*%m{p?R8X{a!X6? zJ=p~kG)Q#*T*KOYZppjgnKn%73$YWJ=L64||39|)fuo2Vd43n44*9~k>1{(_+;?Z_ zOpXhzKf6%V3Q!e@5ZlHShC?aYK>{ET3*}%K=pksTpd_;# z_woRX=6_}z#6JJ^e`Oj3pMq<$&-Yl%wPlRpfe=Ez2gBKA?0F$|(9MVES+@h5me?PJ zL(U7(KiRr3=NJe;M^RhM;7Wz_e)UpK&(SKG#AneT5(s?cKX(vn^GD5Q8SnbvsYg2FV3)WC6v$Y|l~4|}e!4RsWf z1_nbd1wiD+9}EDx6G1_yF9yck91P@%?=6(VQ51NwJ}ME1Ie(qQ(5uLRjNv3de%ma= z|KTDk4CEjA*)DDI0kgp7bvN$?Tsm0a`g5krG2dTj9Tq6N>EH|s$#j}(P(&ok40LpQ z>L)xbJ2yEcB@*5hXyy(x@83RV*{=`4GjnGUa=}j;m;xVJ(f6^DqBy2sjdSIbC6Q_| z`xe|fUIk2&k0*o*Gj{Zwy z1KW>?+D!44KD^@!zmHY1+$DshPefqIZ$cm$%t)l38!wCVjGWNH|HE>VBE+X9cdVh zDGj~*XVRmX@il1SJr-g1it<6@j+er~ur6)Jsj2E7{!|d}Gg={(h5QVTGk!R>|vK79hU;B}7Pl56i1A3sa?kJGzkxM4h9 zeaIlku0q-qo%&EK&x~pFlL(xILTtw3WkyI4B77V=S|o0sP+e?)^vA>)x8GxtT~5TO z+p(LbS&bqubylpF1>+F-Dt4-tzE^U+4|bPgOUPQ^!_Don$6M2xy*;(WZYl!ai1&{_ zycI4!+0bv$?#t1*8!~J3DIRwD<2{kj7V}xy(d&KA)SuakzQY$L0W!G2X`J261W7r0 z+V5>c<;v+q=W^E6;^lZF->&!%OR?9Qim*RA5V1A|8Mja?iWZNQw2%@|%?dg!N6 zl`;YX+op}KkFE`m7}MHDu5GlLV{<59*1wlysK3Y%k?L1nlkOS8^PHaK9eFsp=VIV*^mYlmx zp7=gQ2swO~ssgli@}^>B3b6J0W?X|b^iH#H`D8Cv&&W94R+$(4D^4>46|4p_aal%t zR!7oPOM3h1a6JlbAJNjeBsE*1xH{MFxUfw&!Y>Ei(?DzE)CNhvPjrbmEt7d$Kl;B} zFL`)lb$%;GdT8(@BLo|*s6Wk+vY9}>8{3or1zmI01p@0Urm424<|}XHNC@akKq?85 zpp?9^4QuwaO5f;|cchbSoOK@K#mRMi*o_o$uD=ME>;Lo}%0U#oc$7Wsi8{Vjy7A3k z@D)#-U*=gj(*KE76pE1hKMmEfQ|=bN;7D)Mp3PR{!S-%ye>2fWwL%`E3HNvDBELcX z^7>6PW2@%2L={T(BbB9-eHj!ndnbn%Ms`8yRKD+o@>KUTYrTRfT zI6e9OP;=jouF$1@_`cihBpMnA)vGr=27x=&(H+PMzt1xVn0Kda!s=;@&8NPA^lC`F zd=}{sBmUQ;!=P!&=XskoRb#}dJW(~*=pZ1goFtD?%i09^xfDugX<&Lz=U6n_SH00t z{k;bp{R>vniw*csy8ztA>&$HhbPQz$cbqB$Yu$#d&2qQzZ_4E#I& zg1kXFub)R~M$N=V>@fbl1SR(skd|wJieDxnE zTh}kAjdHx`-iW4qsg;~A5_)->3V5;r1`?iPutmA`hZ#ADfCL|?JP>(FZ{IfMOEmoB zC1w)ttXP63*Y3CQYoT7Xc_7dt$j1`~WS()o8wN0O;(t#470?4K6ZiSY#k=RHE^H;zbLkMvGv1@Mdv_{mkDE$; z^#POcBav)CNCNBw{p}KhVUNp{i>~uaQhIKe!M`d45u%5|#Rmz0fLmk+}a?ZNTq2*aAKVWP%#h<28JZ&Vs<0sMdL;YA0?DdqM8pDeKD&|Nhp}6;beg$c|T_EHV0#2 ziTYpb^`>U#C#R0s19{MeoTIM^JT%f+9RI9?k2CN1d~hp?fWT86l6)+!)d1~4?C{UPd=R|&>AhonD}hohZ852o!-nx zGg^TGndPSUd=%a$BD#3bV3f{|zUtSGGT+pb#+bgPLrG&uz|YCYJJV6`bQ<>tp245XQ`S&{O2N4yOd8J6(SJEQ|fcD><5SGX5({xHb`?2^Rl ziEGL5o}EC9Yq}djn6cI$&^g@ml1mLk9#BrucvWLvk+Kn|YpL_Q$_m;pt_TxjUUwCcyJCGlC2l9BWoq8{``2!4Wo zMC=YP^p|V8sX0Md07{M8EJB)z<`Ltb7zbRs-XyVtGdAjvkMJ~rD@7?@o)!W@Dm>e% z0m73$SzMeo)o%&LM~WP^PSuWVfU4f)-a9}zKOp1u$W>-I#jrz;)O{-h_e|3KscsUE zmi)`DM;e-5jw7DC@Y#UbQ+mby4Hd28W{~c8t1}7&*bt~r9u^QF{Db%rt{0t^snbyW zxgqMuIKFGQ;7Xy=5(R#4?&{ugVW}_NpbI1upQ9tkKkn~I zd|AebCa+*^@%IL{*58l0QAG2}HWWE%WKujbx(JYmpt?g9H4UTpaCB!cJ_u=Tb-vl8}-TRZVr&%0u13o*6J`7i}#D}E15K*h|8 zy#C*{INuaUD$>`A|LuN}zrpt5VgnYMBUUiq&(rrNYOD!Au!6QD^s3nNITmm&C$yg8 z1+}5luBJwf_-T(OB%sP5P)Dz}Z@fZIy574}f?g`@S#U02*)pIdiZx+fO@fh2h_CVR zzU9P8PN#{*F#KSvo3S?W@Nmk1|1Qr6n1mMW7xCi|re(n&X{kB6YKRR+M|y!mP{F^< z4?NK|4GZJ52i^)g${!c~LqPr0QE2YrIi%M#(7&48?N-fTf)Hqr^_YSx;TTH~*^T4#+7GUu;Wn*2{Dy@Tbe@sNnH8 z8QMZYc5@1qX6>z~AF zsH*?uJ{+A2?x8i~Y=jTZQdEnWKk?6q>Fb$wP?8PdYY>@!+~hm<@^kL&H{(j%!$%Q~ z>4|7gPF=-MW@q;@%ctj0CR=($XWGz8eQr6kB;X<9Oi3?FiCNE2B58CT2mL-i%h{nm z@NR0}`aD=F%^H_u45D>2T!YbngiFiLIpKx17`nS&1ksM72k6m%3G~~$tK_n zyUJJcwMb*-U@N574`%#v%+GjNOvqB9b9=&TqgsU>{Jte%UU$2f=%3!T!wn(34PH9f zdJCM$q|HXIWibwr#t&X?uGCNP#~k$FTpr(OQTz`JMqQMbBHgmmJYz!Z(=(XM34beG zIP0AZJM72+buf{|-4%B*hCnFuW9(L_%i+#mfpyos$TDRT)$`(x1QH$21B*1R!trCk z6KqwPQJu&5!gR2eImQoE+OJI#RUGr7H>Y#x+c8K$?F$Q$^FgsNXV}}zC=N)nUtV%q z*wjEbBFkt-r)pmu1kz0QeQSvVzGWcdjr>coz0~-OFz{13sGB*T$w+5|cLu&dO)js0 za@UcGkK|!asRvN)B~6Y52#CHf+w#`AQk046cR-Tn%VVe>{6X~oP))EM7R{ajZ^>~- zYgi0FT)9q18+R)2wlF$D?2S{+a1ytlTg@;uZRuL>C+vp|cDb2$+`Z)NhiJlzb&Y|= zpJoZj{umTAAPP#r4YXI>0f84px?-fIKGElje&FI$c|F&;89wMo`PE&H=m$FVEBytW z@YwC|XXy+1OH9oH+W%ee1JiJw!?lx3ofEQV^7%71+DaUiRBmAA434Qi`4%ye*~O0 z+cQ_G#;QaE;;XX{$xX|8^d}2p=sOYj#t)g%@eNfaGazcvpu;5J7%8Oy9PNvoDE$b_lqujIO@`V^7xQ=nTgvZ^V7k%_uepq6kcD+w{!W$4=MwnE=Jb zQJryOkWZY^0Tt39_#|Vty0>UY{^))WtmqC)wqL7I|6A0EzX zJ;${x{yZ$;hTWFNALVu*Xt}Z+!@-AjTB1WS;(NPPqGl%XfCH@!B!jap%T|h*HcN{? zgVOkahP;PZ#Pnv;QGQ?(>dZEzDb4^S$n`Xo*B5xm_i?8G)yw;Iiz7ur`-^~FsZRM` zo7=A)4hdt(e0FLeYP_||s6+W}9t{F~>Hi?j%BvNG;s(qD zqngI@AFVK#wbSBq(mW3nze7uu@{$#BG3~JJ}M^1FPiZIwIvc-DU%}@ z{Oz=^shFpJmgQB%lBeDQ+N(?c-a(FlxVXW9?3h&D6%k0M!FuIR@mH3UB*wW1mApXC z85_Q|7!|vi8i5Ci#b4fQDp*fzc}s=Mne6T6w6(pOTVA<=XukwqcPd{Z4&X% z5Fv<<(2tVUIczXMJlXbIuvrj0_JIs@uaTzT_l^COIY^b|2m-eQo22191o9|HPuzx! zTBkrp!;KyoxWYkh*m6z1`ud=%%y)cvh;ZQ7N1~sUup`aBp`NXRjmEm6?hJsk#PPoi z+*1{f2AL{Q1L*QzqM#dANz0L*quC)NeI)I(H>!IPGtI*RuZS4a3;d2wf~y2&o12}m zS^kz%#eOg9bT|kO7F(}0sGmLyRu&$!H;C~yHx39d`Vz7OZfkQj0Ay4J!)-hp)P~l! zNQkw2$}M_WE0+3J!$U*IWW@(|5%?|piXJW=NQ=&ipgE0yy2h(NNvOdH z-`U>h-2;6LV^|0A?X3H0$`tOC_=^sTDPke8eMAhj(^cD3?S$YU zz>(_l;C5^c9G$mG9*A8(ZUdj>GEn~P!3UM?>Zhl5A6Mfh84iDS%&!|2jGMH#C$*W5 zrN~3~Sg2;yR98l4KvEmc`JjyHSi6u87Qy?7o-JIwAT|jqx3LFLUUEX0uMh}LXX;$W z?%VlRV0f6Y+b)NG^wRSxgQ4~H%Tpir_RhzqceA)3?`0H&wgv~1EEN9eme^n%lfs|e z=AF+9GW1q@GEGaJgGBL-gQvvXubzF6PK+gb&}m(NvrxXu$y?agK~;c+Y}o@!wCF z>)N_pS}1X&v&Fn%pe+P|!N_~~CLZ2f>P!i1t))h8YxnfZ1EhA=T3V1lylu9QOonpz zD$g%$gXQ4gqQ(V*kcG1TT>!9JtI3zy0DemzPtOI_(r7!dV~!SO0h~A7la3 znvqrm_~GgkQ)B$?Kbz?*;DWsq4IAA0)%ct9<6(MqgEOHF+wbV+Bz#he6#gEJWF~@sx{pdhbjPLi9(GXpgc@a4h$ZhEEvs)9Msj1yDzj#}>u5!nt+ zm8`2frnvOF=J>4d?P$w57?lHRssj7@ZrK}m@aw)CNi!#wMOzONWWUNE{tCP-%2kd? zPtn|eR_|h^g%16e@2sTlpS->;^T?CEL#x2M`>ib-H2$u*qm;fJbzR9Y`L3=LKKth) z*`HqYL`oGzN+7-p589Q0Y?Y)Ka6(gpq%m+u0Cm=5tP$%Ot2K}Z5J8%AlY7!hkY;2*WehW^O|tUm zcvuW2PITp+8JDx`jPuQ>Cj=$J$pJ~5xXe86m60Smv&~3FMK6H^6#$-#r%2$cBm3Ko zq!BS}DSyMxZsv{*cyqVjE>M{gP4W|qB#(r5*rwCoW?z^z&*h|#)T`sD1KlCC&b)5v zM4`2^c5@zc9*`#I)u%^&t(@%QR7v2l4DGOeHl2lS>*pgJR$>$c)S!@7a%Lo-Nf(Ig zGn$Oq`mg~BuL{OIu%{#hzQkx%9nJCV`d#lZ@x9EAI{8%o8WuF;sk4cI{RmIpfcJGH z8v>a;1N~CvrO)(fOhAC?z`sq$yB`YG+;cs~#J!7=ibZvRA8smAOYs`4n5n-2a;E|` zdI|*EDG;qfWsV*ae~npGF8v(Y%?pjWT0w`ieTbby5-}I&wq3u5^|fLnztHph-d4}IAj62? zOiOIJXEe6KXlX1r+r__$2msM);W!@@xo&JBp?1)P=Hx~NWRvLKG^h?*=wwi20QAd2 zYF{he?pIrSGP@a&);pcPM5i9=0qv;; zQzT@yMdx`4UoSV-LC>#EV=3S19Nb_E*Wc>v@Yx2`9XT*GRPW`W`|$R=3bM0XCQsf~@RRr@lEnYQqV-clC4;{+CGbEQ?wG&#jm*9OX*D8|g+unARzMapZMuB-2xdw~wG_RGH-+*(yIAk6OzU!{&uO0y%t_ws=ArSDpqpK#ted7U*V@vq=W~&>ioZFZD>%!R{6OsiISfuODTW{$>y17ztlQDUAR_e*# z9UPR&Dj4k8nS%l%5IjfS1{@`RbR{SveCqufYV&EcT}(KPtQQa?Uu=8$ zAGm=#3Y5`Ide@g-7MxJVOxO4R=`ip>o6#~q`R@#$FMGtg+_^{vv0PMOy4gTzwZwc) zq|Z}qz?(PN*-{9^Xw>-KUDWmao(XP?}kiWVxt?bH1E z<4%OTkgJqjp$M#^_g8gp@SA+pnyvN~vc(MnHR_V2HX>k{pPnD~W+CFJ$>l@C)`FMQ zT}xVRv$U|$k^91}RoU-E_hyiA_Kg5r09DQM`=GVbecgpPt1?x>N^)x#1oEEu4kKSW zYTi6U_Bp$8_SkNzkg2Q!7u1o3Shc+m%TUA|ig#SQJbqC5WQ@s%jGh(=1~Qe{?Dp5T z3xS*>jZ)nVe4OHPu%fPkRiO~Wi0+7Jt|Nf`oa5X|onP`LxI_2b)w#_wDL8kbuX(~h zS4x73x57KE$Kz!ZR1v;Ap37%9>0V3VdV@e7yZ>g`qOEK5s?>dXPGJj(=Lef9axnIL zRJlI!y;es=RLf!|617M($HgU1w^2jH3z6&W!esWcjrFWuindOz!tVWo4B4@zYmZ7# z5>RLzW|YfoJ_cK@EcyX%vIZUK_@Q8W1S9}7%K$yoO0IY6^vcWLe?JMbHgIn`8EoPW zlQTR$7)vt@=sYlICVRiyl&e^V^ZVsl<7q?C0dM;|n}@~fe>*)D$R#=Gjlx-(iJ;gH zm}}lqvdfVmplm`PXpz461LY|Sid1qtHhfckJMj8DN5TU%Q}J;95?Tk|lpEFx@q(JU zykY2140U2jIXqzf|`Fdw|OB?pSzpZ_JUi*7DlR>*P zfN{1rm^~l{w|Aa8UQGw+h{!fvRrAYFVuNYyGqZ#bcC(LlPMi-H4P_!n`gV9A5NO^p zTHu8^W`zjdV<{3|$uY09kCb^gD!5B9Q>*EN!sS4V^OQ+Yx;w-HfamDdRBT>7FL`1w2re)?t4QrlcGY@NlnnGP$fWPcC|*0CN5W)aI_7W|5mc6A1>`h65A=XK99^t&=QQpRb_S<&5Y$?=pW&e&9C` z$jMuo0Dgb|GwY8oEx$fzsqT;HG` zW10WPbjm#D@K@i^@idU}##|UlgH6ZzC{plYk&Yg-F{Qq#4VqF>$G4Re4F`GP?g4#1 zg_`~DX1}f@+bS^Bv`{xll*&{arSsdAnxBBDrf$N;c89LT?XJIL%N=%p@?iG2N*~WY z<{q<5w>_F>-cQQfcRTZbL@r}fctpeWM5pwm4NmyLcZRdJMtzz(mqP|9e@gmoQi4oT zqwSRcA~S$)MIcRQwz8+?d)`f(-t`*EUB3oe=YlA}GJyOrkUL;N^1pK|7Cq>;uWS;W z$e4U=MRa$;)wIdhdKnf^Irrv@n%>^Pa6=eyeknf+8SL>1mLB#Q%ehYujnDzP!^loy zka|mCVa7?ijtYdW6qny>`73}(fdbXV*-HK2wdE?pPiO(!b&{n^n0#CZEx5!dg5v>b zs|O4&^n>s&cPLxCN&Ct_xDMjBB+Hj$D|XoSO=+pYpjXNkF&^wkUD!~~pjy8`na$&` zFc6Txx~#Tw06)&dD*8<%7-F4s{3oe(8o|P`b(BfQxrK-Q795>`w(F8-3;$u+J=JW` z4L0RV_mWPQz4a$Hm7G_?01prVZ4}rhD5Y{c0Pr|cp7ZIw zmVqGy=s)j_eY$pTOFLnGD%^#56OYIXzXBgCP02I(0Cvp1O2LPbi94wKnG9}+>MV&V zKB@{i&pj80(yO4CB%uE=i-#BdKtl{;iut(m zxc#=>-I5l^A4Rv8GSFwXk1Fn6^i|a~?&Zm#NZW}sSL1DhLH_P@GHNjB5MJ~7pDt%b zF8_jIT0VD@B%LR*T$$k7>J8l08B8lr2q#>yTU6#>a6W=PH>~pe;sgSGmNz%5y4<$t zK!ute#(D2UPQp&FePyKZK9cUzeK%IyiBwgwW~WMYoAkO&H{?R-Ts$!u>y-WxFI{He z_mBI$Cgq=a3COpcM8BJ zK0(jJ^JUWKLau$XhF&OriYl4yKvI@+;QDpCN2E4pZQR~fNO$vyc;wZs3pYM6Rlq<- zTJDNm*?SOqBE%n!y+iu3Kj&RR4Ww=pP)3qAd6I#W6agYkU_b$>U|goi!zTcqNi>l7 zq|a%&8|SX~`{+IR2=pe5hSotjderzXRQ`rN2YbP)cZ1?YtL##3d(yx)GbKSAv?^;( zX`@kIO9R4jIY<$J{bs)lGX+dIBmfvM0m_EnSTAD}qDmeGLn;n|Yx2U-zu1m&`7E>qy+klmZQI*9ap=hWfC>OI5BpVOa0);I6qw;4 z#d)ufc2``$)6D44<^3PD$i9K2@RcmPoVR1=6#Rqby&p| ze_dpNj}|H&WqgAH2NAA2Vo!B~YK5nZicY^b-hiVCkVU$0*kBCrFYyHg?UKO#mLnV7 zZ&uhz>{e9u=_!+!vnW;eG0zx4FkU{+CWNh`jTGCTIziXe{Xh-op8iAn%lre5FBWsF z!O+J{ZpdzTPvW9V`sXaT74cyWduppWYpR+0r};m8j_e1}xVwg|Q!1O@g%8XD1a!yH z{dBg_g7q{)c$sm}NT(0c$i@t2WRW;4Zih^^w09dNKNDUJ879CTe_raLf4wIVxgK2C zuT_MI6Tv~MKX9ioMYaB!?!pu3{qLsFfUcMPd)7h^E6}fr2LFVnuQ6&-A1sXEwu%2B zA|k%bTz(<55rQwGYPBwy!3+e)1Ih(x$d?lywG~}hLk2rRv;7J>pclNbVgk=-ppAS$ zZda=cBKx}f@TeHtdHc&JM+7Ln`}d`l`l0oH3x5vUG~W6gxOKrah(AeRtw;Z*bOzs} zcPC=Tnqg0hk)#RytK%ZaT0$#B3f0QCJqUvb#~`tCnZmufEfPb3ZcFJA~ETT}FAsjx-f+5eflTNCleCJOz# zzb|sCi#G_S5)v@y2a9ouiFlyny+rX;XIB^8!oq^(udmQ|wEvp_?+^RupK)o#s@NNA zYbFypGGI)U7<8ariqdDU8!vC~T;n!B=oAKMfALS4pgR1oudRgV<|v(Dy8Q|mf6)$l z#$$rU#?d{nU^oDn`*nG9(=|W;|FHMoQBj^kNx>4y>s?@Qf^hlMiQbu|&12fFIZic)`zTdajcYbS~b^bcDUhzdh#ocQo(`wIBr)c(&gkzPmqb}CzL`$xlX4E>2yK0-?rr#>z`U{$vQeQX7OUA1{xT*l|p z-+tP5z_$5&ocUg)Bvx<(W9EVGBU9+^lWJMZJ-Y7idyYShPti^`aIK6wj05#a=bY2Q z;@|Avy}NephQAgd`TzHilW-HYId%`}bVU}MonWKz?}-nFCy%B4yV)fAD4`{YrfWoB zpEcT!3SRAri&?#K+eR(Du<-D$_fJ}O=#f#Kot+tUkFY@VfGN!O*6)7*GvC#Ld^nE# ze~FHGm#2o#lU7}r*`b(1>|KRf?wQKHV% zuNLzP{A6!0jxI8nRranYZACZBI6g=JkNxQTX_;qBj;TkhP_$E7WYd^&_~gkYXi5{n zIGa#=wKpkys6-@%Gy*%~e&ND4(rE~7s4k!v2x%+Pcb7Wfm@2Pe<-f+!QvLDLmuQ0d z>389X=kFh%IR8|qZ(@J&>1UTeeYMZv8EGx@7#(eNo?iF{?cN8O)9TTua~RS*AH!W* zpEk9oo_; z_(?Ycv>1F|U7a5NGU4F+_wRYG_133oH-!mOS^E{(W9nh@?(5NY<*EKJLH2RaFUE9}Bb@qmF&?s&QO}AcER`Q~$dwgN&+^GTcSM*3GrRcjGXv1** z@#F32RJur5IOdek6(^@R=+gJ?>ebiKWmkUf(wDFW3wf{s^~@vD4W**zDq1~VWp)OU z{gk)kC-ihm#hOANjK$u3V!xi=`3g9$A5mOFm= za{;F2r*r4Fklq)t0~dN~;?P=Ba7$1r>8I8d-h|eV=pp}8icvx8xffrdf3RcUhi}F5 zhrViVX?cndJb)fBCHti=RlT-#S5x~AonG!mL>P5S|Nh%R%r4E^S^ANmYPSoQotd-T zuURu;zkRv?d{X+S%=zghCO^=jNDEuvN$I<3}j?i?W3hZald+@!hD;$Y|%bhEZs zqA4sPot*29+*fY9zXVMyZqey<#S>N$t+{D4+5 z_?k23#x5={WYS0zDxPy6pZ^Rm)M@xO;@-V8=BIf+c6YnhJFNM(g?Bh4AO8Q#{SN!B zkF$<$aT6LIbb#4j*v430yQ;8qYrMQiin_F9rrYY+7wc{e=8vv=qPClKLmcaCSOW|2 zCey0EZ=%3$d=GZ=Z8Q)%2jhXx%~u)WE|z;T-0g=yA0sOYy~EKpW8>_23bmzHZsn$f zU-Mw=1UA{BNA&|V=?x1D`x9niuWnk$&+Nvz!*}OR>-*~|7?^{0t^DNcBr4^uqv7IK z^atX>O8L~)wFq{)z{nT9biRxu!-rO3r=nh>5znva5;>D)zopsz@>^c;Od%sgCptP< zFRoKfcf-Dm3VQO%adh!@#XI=a)*L%__|F3dYu2q>x7y&)u>}p^K4=qEIKAb1TDg1) z+ly{A>eiK8kyF7mdX?bXHI7dbl? zvq8fAbRoOnPTN8^!4zAm!F{j*?TJ3HJ}y>#XBK3bRbqmrD28S`RN7vbJc2B`?j2o4 z!Nj1;ivMQk4_o0Lit&1Njss(?w_c3T(DD7A`nk*EIdmyLOLz&NK9xQ00E@vGDHTIM zxO_iyjWC{9x+=m$;XR^UN4vUP3*6D@N4q7;BQ7C1caOHAljbVvj(brrvtfWN&|#?( z9^d~Juc924-D|AAYi%LA^BB2{R?3Z@UZuzys!7t%9$}#~=7Tk=a{Dop&Iel@tE~ZF zN)2YvfV}hhLf+Ff7TODyCFc$3HPmI zeOoKAdS4uRcU?m3y!TysoY@(N7vJz{2#B?ojD@zb%sX??$xs&E2B|bt7)6X_nfTy0 ztbRLcD7r?s#i0W_8cl}_&JNwutZ+eFa4j$Nf$H$U?<(VBB0`d~8uy6mXO&^hgU}#( zAO8L%zRKb4W*3bO2UH&XuoYIj#L%TBKz_2>A$VfdW*O5DsplJ98R2;AgLRyrxx;*B zzl~EqpP81@TgM63>sU~9mC!XE=R-J^R@-!b)6IX;#>MK$||_yJN(k(Vg+Mr z3e#U6W8kC!%_q{%;Q_yTc?>D*e8g;Z(EDx&V<$kC4 zx;=PU<}R@`f~BPip{0BBr&BzII-htwt#ZHKMHw`f&xuyLMys)W=_G0R70_{TRr;8nOz<5Qz_l{+z&sOG!n3U3Cu>t=?6a?KD31U zYVAi6zSXwrvTq^xLRnYinkL z)yIl(*_4q?^zkyh9>9H_$6HhX&{UniUDTOfPaZwG5RfdkcA3?)ZQS2JZ2ama+SZu? z&gFE4SXH6zZb0tDyY0SZcQ#1G$BEB>wqfo1)wVgEfv$Z|-&Ub#SRb7;E8=8>js}vN z-*Dd(m6Vvf2wk$0{DdkMW~`l*=SL&)iRH>AB763Xlu1-3j%6lHx6WViM{muh^^al6 zvMcYFxAWD+(@ z|9HDBREi$0bY9ZOW45z-%y(v#vu|2iCc`i#x^w~0wT({w{j(^v(ykHPZJLT zX1zp%WwnoegKJETQHzYE;deaG>-BuB_fMmDS|`?H+0IaF8hR_npt*O$s~8eEx-Tayx)iQ$<9KAy+a-(Z#VhZ+?O@vpYE+x? z@OboNn_Vs55%dH$HV3dYP)5;XYG`brMK;L5o?_(1>_n?H5r?j_;!Sq>_laF^uudK8_*fMki#B~ zv8asbhG}&QSiJXqS?C(e+^^${d)s$LEB88C<*wl2iTyUEfT^V%bV%~S)IbM%2a2J; zdQ!I<4^LrMyyzAZ+or!*wCM@OZ)X4lUMm+3x<%1J*L?u(xXB)lvr>^CuL*PQ3TC!( z*lny&ywrEfQY{*WoePv6NVU~HvPzJ4V9Z&@GbUZ-oQ!CthaqLb!dt%P(b)_u<2CmF z1R0n>87cdYcWam3G{DG}-ni}!Cw{EEjZQ^ZUM<6@1(E7!&w84qG(FOFZZ=Q1ZpgEc zarv;N_kdls`6gD0r&biDb5-T^yWgF{!R z&+rwXyl~g)Go375x0SSyNgAI z;}t=79g3x6%wp_pja*sH-YH#g>4jJ06gYhxx-Q+RbX~kvFZ-@;&L-WZ>*9*223YyG z6JjggR0MtSyp@zbZoVnXB3q@q@u^LZe~@-oQqRV7k8Q(8TCPu&m~V2alqnB2pJl0z z*2(&f44J5!ee`@K&Z+KP{3yK1?%p8(GmEU|3=6%aYB`Fy)3mUwNpsVN=@g5R=Z{sj zCX%rDRwi;J^ltFN&fid@fgF6oTTg7fL-*Sd^l(*lU4q+`+&c(v7@HICT3*k-Qgs8XPrI#u>Y|oSK6ev zwB$MJn6!INK9qN89Bxxe(lZNSV9YGqSLT_LpGd>3Qq{^?lW3~AL$H;vf;BzN<8)+H zvNCOYT+f_cCt96Bc^viC%jxl+t~+b!*+rI%c|In@z~_}8Y}(7K$hu4%7NW8J6kypP z0>02!)zJ3jvy1N^-4Rq&5Y@|wO|H7O@rTa#ORk@T&$_yd+|s3w=Qdc>_p7DYh6@$% z+~ib>JF7IdEtD;`P+c`4-~X)S^5y0mE)pG2?|0YOUhgX?mGTpkwaYuEGS&KmNu|!b zT~_}?zhAKH#GCq&fhWA;o<(IfDcVv^)`eHxY(7X+)QpT|g$0`|au%FVa0*Eo+Lpty zxn31}{<>L3eVlrP;AW?joiTQt@v-G128uj9L*r-fcl1xTRwmfpnOi-j%5KRcI-8t& z{i{`{j5=j5G{1I{7#p^z%GdSdPi3r%(~nWX>UdBXBM@Y4bh%TlULa$#g4uATEl$pB zG%0&@jA~R|(;fNb?5jnWnxZuWWOF^HyMMV8I%P}jYF4Y`W9vVoYxmUi@;uj8&HU!Z zTYd+efU<$6t|Gl-ven(Lq@-lzx*%%{!Nm?;!4_Sx6sMB(G6Qo4sIv^}2%_UIz`+y@ z?KU@R$q8u&l|$S5>V)~HN5=vfi+SC}>t{bDmbCY0wPuKKk((H@)3l{aI820v4V5k* zEiTig&~D#8MkawKbb>$MW%l6Z%jV?~=|?j8-B`h;VKbx3E}a#01OHVvO+v+dw#yqQ zGX$frJpR>nkj*Q>_S5ZdNNTsAX&wFZw|J`$cJ(F4JZUVOj~~~?Z^mYIV}7a>i}hIi zdFcs(e7yyq7Pt@S5gH(59Tv#6qcX~4N|AhswiZs(oBPFvQO9F?H}pVw-csp ztVsz{r%tt99-3|%D38?scr;MLqP(J_ZT75!Vd@sj2*aByN=#iTz|Gay-=66Bb(8J8 zg}ie2HoG+M2UAwQFRzaVvFMNV>DrXE_s2JNg1=HhlUG2xaPTqv7A{)x4kT*?A`?cL zL=c;o`MkdJm9M;ECzF#ex3fCB@??6nkDriJiQmeEG7nj(Aw;$Z$r<KxF3T186 zbvh)GDl|Vo5tc)|n{C0gOZT771!a+(OJph87BK!`^>&{vEz9{4+x}d?25b|p z*`ZP%f!5Ix;Lr~JwrMS_r!+)FPrxd;qK~)3ba#wcEghjGg7ylhdcz>mtPhW5lonck zfbonz`hQc)yfR%E;7CRbC)>k#k+&G62DTqTmu}dmr&psYu zLuo8i`jw>)c2U)2O}p9*Cu>85yoK$oGlNZ`Mf+?Lt(_cgx>XS(?-JK;t{gZ(kKdUu zD=i#3K9fJz_`$5PLYh5w!1|2hn$}?dgz2`YB@P~(RJI76+S@-VexYMDL2fA;#?r58W=S$L73Y54l`Kui5+m(LDt9?`LV*58`1 z*x6n(s>qzpPCp;Cs#YCP-I!84{q?5{ic>Fm=L`W_1D7H)4Fzm4gTsl_PB99RvWo!C zVv6}x1)ipR{P58#={(yl=ZloEMN`s23auYLa{eSO&745rZh#mk%0gEV^b z$ic(Y4oIG~yO-~Cz&zbK8HtrYCv(7eiZad#AL}5&9>Uh1KH!f}P zmmOJ)L{kdJs1rH;k-VKFy|zjt4ZT}ty~+n^I5o@+zR*U(yzfTOEkVm@FLRCA10{0omm&e4(UlK!Y9v?o)a&-xlay$!8Q$M$?M}LuN63Cpe&9lo@f9ho{c#*vch8SJkHL zzcp|h=+>R&WIZEh)S<{V$gJ;%OHj{b1$>y*L|Dm|)PBbLxJX4Wljhg8@m4VA%=&6E z)6jA$6&i7H82PZb$7h9CZmLl&aC7n#dMhxS#P2?p9nWuqiRub|v2bMi!&2#y7{=b} z)+m!E>q)BW>7B(kldFo+En7RmmWStI27-o;fyK8N_z`pRdbMwJpqu8aSnbwWmRGa> zppdLr`|={z;N)mmQlsr|0CRcGUfCrJ<_vY2u@_a#?8m9kB6b}gs;5QQtlAkV>Ly5? zzFE>P>Na_@x`^X>t4Kg;uPvi2Jm`KU-N2OvNb*kN`eF^Q%uAWJW=%Wg^P*j=i%QG1 zVOv`QL-%^xuN-7C*{>!NSYAbCx;eeG=dGM!CO%r(0WbyXRa(E#d#l#3lNoo{a}% zBds_?O)FfjO=Ys%&c$6D%A&3+PHb?{ioGlE%sM{pVKvEObnbT@FV7w+zN34o?#Y#` zCQCoZfz-oh^x$-JMoy4Ile0xdvn#=u@%*@*O}O=)6k1%hV)3Y$PqDk7qCvEAuamxM zN_bYBT%py_<6udPfn~MWQy!k4?(~*DE{7VPSky-~s<1z~h|mQ`KD^7HF48xaX^mop z-#@Z-gS@#xv`Xaty56jJ`ZWNAb7t+chczvJX1Y0piflg zhJR9+ooYugrLzz*!b5TX)3oC&H+2ZNahwA@`4-${%=Y9S)%wc*s&M-lX1QLo0 zZrxS{lp45~wuopd^*3e$xEG}1C*C>|c6WOAQ)R#(?HQ`(9rYXs%>Yl1$8bt|C)U&J zx5`&ta@d6~?-b$;XVQ%`N4qL3+c?suQI?AkMnz0CcW-L2_lwZX+C1Lxi9wP56uxfR zlI1s_xcDY_w><6@w}C4tsa-fw_VAXVadbbcG`X~v<6w>*9Hf_7Vn5U>D^{z34ae$w zWF^ja7<4gCE83B1n^U1?<0hHiHgma1*?ZzK}lv>H3|FYh?plyG`0Y3OOUWBN1@tRfvDM_)h zC;vF3bG??wDo+od!11vyg|>D3@vgSY0*zPeKp7=I6Lp_FNz0!q*_4B=A!)FZ2k*69 zzC@8eE_lePry)t5>{3Bc-JS>W&D@onG_d*xU| zuY%Dn6_6rRtv!F#g!HA1BZ|Z=TC_6o{a*#9$75njb+U?>JL<4N@Kci>%Q#g3{tCdUM*rW znL&)F_c^+qS1fDe1GZ+*?pEF?RW_+I5GqwfFVL8o(6+Dg_`-YI#neE-CZ+_8ZG@IE>%4{--Z<@I!W`=cB_@(ww0Z?F&?$;!&~6_)w8TTS zKdnEILZd29s&!GPyL)^xTP#dgQb%%19Bi60N;0WaHW-79k!FA2m`-C=`aKt6>I1lY z3wB+vd9_b&6d;mkH@{S~vz>WaSoosxUP50)tju#fNz1rcC&UQ9a^&u8Lr#Jd*!1W1 zkxr|4c;2W1jDzn_YjJMs1VyVK-T(BhfD+cA&EaLnSEz+k^~8=Th06}VJ8J*={w)mw zJ8TpR?bnq(F_qcQl@(34_g8I7403boH3y<_g?Gzg{PC?wrewvcwhssBoPJex2B-D> z^cTO%&P4B&3OIR@nY~7R3ez3w0Svut>-#1%9wPd7S!eZ9nzor`!V2Bd4Q72h;o3Pq z?0W8lpQ~%x2eYmwW@xCGL|sFSAid_1E4DD>bHXBxoRfpCIi7WPEhEbZyCW8#(6IHb z=1d&UX&2)ozJfv%Fq8#H(A!loX%h)*y|!HDt;}MyE?==EMmaQ4Re|}WB+JmbVTbI% zGn3w7IeZ@0Ye$@IrZ!4Yb zW9ioU=Yz_EC0v4UjuN2JDN`;%n1}An%0?_FX?^_OEVBNsa@k<=GX?MUnuaqV@48^l z=|hEqW8J@eupvfkUkD#UJ|O6_JU_*E*kpkbEiyYP6K$R|cOs{37d5+b z>={kwqdudz3OR-c+D1PPu__~N4Z0@Z?8AJgrKQQuuu6j^-BoC&lbf@W4GySA?w6R9 zXI+i<6qo4Dx)=yjs-jD2?~NdVUQSjEb8<9b)Wam*C$I-hoUfS6`}{p6R`n{r$(w^@ zCg0a|TUmSX@SI~NIw+Wgqg8xk{`J*5$+raEWhjAqKhmdqHsvsr7w_?PX*3@tex`Fc za=Pt)C}&uAW<0wkG&H#{D0y?!4$56Q$4C<6?@`<`Nbzg(vtA*<E*EvvxA&{&c@%%@5D|$g(PTUL7Od13L}4f!kySSSI@p1 zaN1uyhThji74=j-wP{RhlZh;Q5MW{K;gmD^Y-hw(g@L@qTCQG=V;j0%x{bBn$$*zX zt+bA*JVn6Iv7 zvUb~5kswxQAgg1hP37^gE2|JXl-M96+0`e_Z0W4G_(G3}lgkD(u!Ax?UbYsqWwbLzyRMuTM)e;8|WvJSO=Jq2lBw+u;Ts)xGj{5|52^d#<*%WL5T0DyT7+ zHWAxRSbk=mfNbK>OQv^Lr*2h7cjR8D=IHU<(lOnU=9f3U#o2zrA{sD9sTUF*5|Lx7 zbFSUEx^TE?lk~u|cCB1d%CehBlIgS4J$qK_ECscCvcEBsof|uuqv}+M&7ib-kp6a6 zz+hLbVV>4RgHv8f)}?9-)6a}nho>K+s5z1+1aeggu>*BlsjEfUJ=akMLreGYT&LcS z&``TD({7hbLdChFA!d;CF|RfXyUrp;zz>B6V%_}9=6azymg{bRU8R*>VZgY4N+?Fc z6k!5(vY4o3-Ix9CBM7iGY(+B8W~(sY`1+RMdv`2{S@}8gL%chJPjQ>+I|pXylfU&n zEAIkF%DzLLDG-V*ctu!UjR~q}^G}h9-|$@1JQqH<=S$)xx3E55=d~vp>xKy^(6k%0 zcCA%Adv@!V8!7iDV@{VZ>dib&^RV*WEJ2-(f%UP4H$E&t?NQvII9%XpN%x|4_?z&m zW~=lQ$~B!Wbx;}4_q@&N99SGQ6XC7OxA$EzJx{#eU@T#5;JqYamTuB#DPC7**9A1S zf*TmDNzj%;cBD=`Qmw_KeWk+-F=`l}vFG-7;ldzZ#fU%+4O8$Iy76syf+IcCT!w`m zPbk1$v)hDif*fZ%>lwM6{!kFK z%TK>fm+onr*-6nmU`1=EyIiJq496J18=@)_uJ_oi<0rA|=-OOJ0rt-HXR*aiU)fPZ)GG@PRMN-vI`a+$vNHN(IV+O151(sxPzjlEvZvS+l(Ewbr!_V$}c z(wbK3$C!DW2Q&@sJ@~1mnAuuvF&9Hrv^C3tutVmmu6S%i`bf#9c89^EzM?KqXIv3@ zm#Bh)H*le}Wgwu3PAXKXQeHGB-2GDLk&)5Cs*|*{kBuv9rCNiP-JF$gW!Y3c()A-# zFJVzp@hrQobi1#7c-|sj5J8{mh}f88)!(?X;`jpfOFt|f>~fD8D^G5+321ToAVV1V zuJVv4Dz5!yArex4t&ue147&M@irZD$!VX;(`Y$5C@adfESv5fMrO^o>E)ZnFqDi!C2S(ceaGDw=(*HwDm1G^CxIl*q<7(bNI%`%?d zHM4Kn%Rlq*-)afN@w|X*J+)l5gh$e5*OCTT&*s-5`%H2Y9kg1G0+EY`NXB)G*L~m1 zAG@JW`iw)Olh$b`b15?`kDsntrcN1dd=+`s*~e%6i(i5w=*8#pX%!229!6^*@h@UI z_T{}RB&sA2h2+b=28&3~s)vWKhwNKEb(_hqT{tq6usHGlQx$mzO;Q|q^m$Jz_phxE zk^GE-ON)~$gm&2QEQq)ygd;udfp~-9qi{@}ZzZyh{ySWx4UrPHhy2sx)@J9}imJH~ zhUcLW_6+%uzU($H@*#dg&_0p0t$mAAy;;Y*`$P>Pe(t$NB+aV%*%eHMI8wK#A+KM| zfBBG+Yi9uRuhF?B3j@y>8MX7SQCCJ@s-EaQ&X};|^C#j>HZSK%)r;@ne1knRS$S{k zQMFxbNtU*Yp~!J{{&4=52PC@E5YagFXV#lTLQ)EvR~$-F59FX-2XYuyc`k$GdCo|Q zNJ7Gaj}8ZQQp`wVmSkjM$snHM` z?OwYK|4Wc-g&4^JlNEjR7Sc!Awla=i?jiduieNy0K$1|`kY)u!%M;_pp1#zOW<~^1 zM7NhkFgY|; zr6gt+?o?NdrBP7BBws{HBWY)(uS*T;F*4SvFBWZ(vTe4kzYGCXhZLQuOjIQKGXr@E zl6C&O)kl2nJhp2axtbBZ6YlLtC^IY8X&c(DuFniHDQwZ0?&(d*CiH6^)pR%gD$>8PqP$kagNlcQ_EPsyuk zb@@Fo)`n@aKRG(euy9Uijs1QK z3fFTaQHBAy5{C>}PhC>z!c{+RL00|fB3^+o6JPm*6mc)s6xk7RsEL?pqjJoOg%##MRHgOS+`r!yV`m zzF|&qmU_H)x0q`9$Ov~oH8-*!6mJxVpjfDQy}tEXvzgDM(D!~8ueXMTY#99I?lS`Aq^Ym`1Tk{zMn)0+Sxx(dBd%W1%eZ*|gn%nizZi78>*gEx zGkiA79)u3F0v0&%QZIy?J@qM((hgl6wc5r>BhznGe^LpTk!*kKYaCsGRnu_su@B8- zJ&*RfucWN%62%V)NZA@hD~D#Cjy!17_ylri3Jq3ehXbR+Fj^Hk*(zwug^D2)Gd3$2 zD3;HONg0&4F`!>*6gwFrsu!))nlj#?O>^WNh|_+PXSTCOC7RPVIvQG*lGc1@w`$iq zTJE~yUn8|z7;j`E*=^C@v{lOkZ6ftg8gb~0-|367`g%vB19q(8406`w2~Un`(qpKD z1$~h|Ho}#ve!7v$L#KkKHS3y~{W)<`^lc8kYsHfdYX?^~UhE4KDi&U8)fqnCv$xvF zW%iMytGf}c<5gGY=y`857mJJH4(A&SISsrM;dM?Ul*eY-7wwqCF_)yr(*+wpI6FzC zRHTJodPcr8FJ@YaK9^jQ9JL8OnZGfGdgEe2B zsGGiGgSZex&dJ>)ry4zdSt1B)K8RLe$h z8l^u_b#%M(Vq!Q@V-OU=W@dKGKacX3x45*0izGkWMnukmmpJh7!NlX-5^%c+u9hl9 zNzNZB33Ka-BS~%hdP5gchjM7i$^ByEZ9`Sfd%g-PfJf(zH7X}*#B}E{yF#_&XV*QQ zDt?K&l$7D8^~d!GT}~4V)!Sg$2m#gq$7G$llewy9#eacfrC!ayn?WI{KpBBPTce?H z$^(ToTD>9hD>fe3B^>Ukwo%G9%(f*L2FJG8&TBL@Aa7poIHma2%XO*U(V;taXN|JF zEA7&Xca|EALNd*$Rt;+TIxpTzVGu=Fry~_S)vy&-2*~}o{rH`elLIXd!*BQ)5Xt$M+!4@URwV`f=fH|U*gBhWhJy^JS z)oB|4<>w1DE2L3u!EgU=>uwQ|QUJxYSF2>nj@%&QcoCK<49KJNk@r5m>n?y{cpTcj zdiB0f@d?SNda; zoS;g#0RCjYRh+#~^5X=Fpa_Z#dklbA&RCUFP)^sg_PeF1wQ-niafvzME70IM*&<<9 z;(zD+{oBBUR{}qTk$oWi;D_@^h-V_V;to_P-jAZHcPNiZ15U6gd$3=GbA5h>Op|IC zK$!Dy#kTgvMCdD|R2A8dMBY!789BgM)o`?|J_|)9x)-|&%h)WL%n6F3g2_H0Zf^kC z7ZL<^E)hMa4&Uk-Zv%(4LK8qkjXG)%Pf1ti<&_2VF8g+4u8z zCP`d5_peT?YWS_f*Qtj$__+KxTI=1@vPg`r)sLh@RhcvNa=FCikI%_HzS)Z` z4|$wA^8TxrP?DQg%1@S|+@TsPmfxx7EWzDSV;7Lp2QfO!-Zr;8nt1#M46OwGb|R+g z&7OHp``m{pLPQxX=W6fV>dI)o@^&-eb*H$e*JNH>lxn2HhW!_Qnw@&9*fmM#cs?xw zio@d??Tl99?__7mIQ0+&nL@IjPIbD%Bm_{$=r=>+|7Oo$*Lob+x^&G>Q}8P)30g^o zsm7?nN(bp=0k?iE9{F7IICTPVgH{({%Khdvq~=ncot?D?-*SKJAMxm$8}e>r@8G8i z01@kx!hZ?k>4pWnkJ@zdxK?#GxXIShm8 zyL(#$4u8U4lE5rOSqdS65QRw93I;vJP~8#mpT#1%OYhJFA*{-BvXR=@11_$Jhs06g zfE}})L=uo~J%Hj>cC)PWUULd}0M4kP=wV;oQFDDe^e~+8rSWg1hSv_jME4z;m z#dOQY+0gPt;-ASOB7Y^+)G;JQo9xQ4qqSoP;X58IShz^T;n@7=y4_yCI{6OK1)bSd z^N-?yl5cL1$m`ja_9|eukbUz?{5ly@3Gi0HeebJVyv9AvuYX_TFw10!O(OT()ln=Q zNtR=JlWm@+_NDo6NF&-_>`k&$qq3HHB3>p>#qyIOvce#j!?bBa5<2(xL&fokZAi$v z|1H0|tK7Ng*D}Be!JerGu$SLFLJ-zI!(8~fRTQ{|_=E%8Z`q}bET#a|{L)_@Uf<)p zc*Xbk;CG@>byzF8gZq}D9k>uy{&4%-7)C|VQ!K&l^9H*K*01>yG8trxzqJL zM5E@~5Cq(+(_`I!?n6 zaUN`yho|6F6~fAvHs?5apIOKKwSOS?F$&wt08y|SNhuJ*IPHI}Mq5}qS1b}z!!^Ip z*%M6-pcT9|=8ISS@8V-l@L%}}(cl9=6`9?rIZR2bL|KXb=ih~+vCpCI+pG203KX~G zvwE~f<;SgrE+K+QNcy$`%bSvdAfaMXauj6bwi~9`lgv?HMzsp@<+n9A&R_K%d>zcu z2Dz)33rE8TO%V%sH)Y$I=>wh62)ok&M@ASCeD%Kv%@~y9aUMMsT`;V4XTssB9}|z(q%7s)%M3(XU=2B>Lp>&~%;XmOflWg)}3K z_!U=(B}c0xXAh9Nt{Y3&*uOuhW{=TwVwN4C z?X!ppcP>F*R~+Uy3C_wYv#ild2En8@ z*)4&^e!^sz64_y%S&mYVx(x1TJpT^n`|kJqAL*!bC!VMD ztXB2Bt>$^7E~>g(Vd8@cg$8#`(4JmTbIFHGG2wL=4Y?n1rYC;KmS*s4urX<$e%;g+ zppDd0^KIyM350d>hxCR#b|mP1^)lSKarez5TLDLnySONn=ZpjXo?pRZY8mpUbC25z ziP7`p!E-z2;(|{=6+uuPU%fd!J3C3Q#6PvmiEIOpL*2Xe4xxlZ${Z!1A~ll${AJe` zER4_DHur0uUGm1oqR096>fAf>#~D%yLmtod7^sW`4fF$)pF}ZWmAjVv8=c;ok;RGI`%RqBa+^1g=pUo9Im*Zuxf>EUOy@FnER7CcQXVgP#It|uyM~AYOhuB zv~VaFEEK5~vv4xs@- zl@|i$7uJrbM4dH;PPLlMdirbt2}g-`NK4-);K(2M0QIo)6^b$C56F6kX|Cj@PQ60O zL!r)phJ9D#nj4_U-3f!L$nI9f(9MeXqk8EIFg-mAE|T!q)d|`uBq;kQW{ue50am*T zrMZd7RAN8AJxAM%vd=g#Mwo+2$bm~V@_GM8Xc(pgy@}~6omX&Ff}ap6VT?CMfgzcN zvyYFj`|8d`Xhn}^0&6weY5CX1g9dQ3j2%3o;-SpK;DH89J z*s}v6+7)DDRTP*V7yh&S9y>qXDFehHjw)oNjl?Kws77suv6Ln+0z#e+`*DA<#o!j$ zAZjo?-XH(^<_{#Y!tfgQ7~t z!t;+0%*R$*8vPfPYY?`RwRdy65PtbmYd+JmQX1KuzN{wO4bt|eq=Fta#V#bgw5BHq zBLP6gwx76r8u$-|*HH0wJutZ&mu3Ul8ZfBTj7T=VFCCv`fm&?+E*3(XXP>3!J%rUE zAZ|j&;{(bNKvZMdbWUtS8DT5@V;_S11bN0FVo*}IjZngdl>Y!P?V#){7^`AI7)YE* z(5F|osNFvzNASx26&<}FW{nBhnJa-`xrtN<5i&rxwu!sGu1Q8;)W0Ken)mRY1rI0}2);}3&92`Lq@V^52^xYiN|8CVtN|38y7~(jj?)nsCf)rp9Q%e+rjJ;e{q&xiy5pQK}#2pH#?wc-T zU(8$K0zeCIDF>{VWZhdGqe?@8$TnP*)c`m)tex5ryPDXu6dEamDVcb~K%Dso zUq4)Z9Y}@4S12(IuBwCxNAZKgJo0Q51DzB^a~!5PW6- zW}P<5d(1dXiJ8MyE0REoWG4NCf*7d8SN-G6J5nh_O3Y9LQ$o&nfJQ?z`LFXPNDK|^ zsZ1lTKESy45;uF$gzUljSJ0MBUe=TY=|eJT)$R_~pk>W2Ub&$|Kx)2YgKdfkH5!0Y=3*HrN(p~ah2xb1lJk)E z!YLTC&ek$8tRR3A#~wt*9t}l(?bo|>9eSmbL1GCh4^Cg2%Qls@G1dnE7v8q)<;~^l zpkXXG9AA`k{rbdd5as*O!WH+;ZOz@ zvbJ2l{CPg=t$?`cNd?qRIHC*(xrLTJwe?B**+RvrST0O$0|#ecMG=m=(q7b!EF_5- zSY1+d-6(19;(mznCWLTzbqwQ|<@RKO>}}0}hWuwt%U6SGwx~Xz$VG3)9}HE=ax& zh^s*8Cmbs=`f|CLDaaQ6cu#N?6dLl1=c0 z2X67(p9CZzMu<=s0Ssif=T1UF3^RUUD|fap$cbq( zYmtwu4N1XsSSNLw5D{2($}aZnKz6>IKu7Ss(ilUtP9`}Y1=Qcv%Rl>xrNN>ZEw}SZ zw-5a@elDAdB!akIYkpJ(6U+*79!KypQ3a1P)UnzQ886GEY1g0}9i z{KN+Xh!=<|;FJ#ze|kd7k6~7r#NOZZ#LkQ$!Uy7j$vHE=3e&qVf664g0y3flyty($ zjD-9N)c;qIvQg}IQ5P0r-f)!5snCNqT0KyalgBQbo3|y0V+D-e{~;A(x_xk?n;_M_ zOSf_$jm4|VQ_El{2t&kjhNh?dEE$FQuXTD4m8bqY;jR7;RH^!(Q#YO<^dx*JQ2dys zFA*S+Q-N|o{?;GTr3EomNc412DAd)0!~3S!AvWSMS{BIqA{9sqGz!YTz3bQLTNF9D3l?Q$y4> z5b0=-j8&^Lj5*WVz$O6*7fB{es>M;gaI#=f^F%QB%Ffs$K(@wmo3jU>`RMOR+8yZ* z3BYz2f>^jjCrFu5|dsN571{QLM_xWw+slw|+2skNl zT2Rq&UnWdI_g8@jU>8W*ijZzt02+U@KrxLGTlxT%Q2wFG3=l#g*K0L~0;d*6oB&XW z7S67c`SK7#Snda&(#ljg{H(+M>*Y6JE-TPZg^`~#HQ<)$C=Q-`(={3@j6Z>02_%|~ zVvED)fsW_A%3axk*fc7HQviC^Fate%<`>{8BUHG<7%Rh%iun4%Qd>X!^lLW&h?DaC zTReWZY`Ne_0d*->r!jxi_Wrew+BuI z+k++`v?6uw%lvY?4MVENbH}>R1TV_!>qy^--KMm3M#cBB}zbZk>u)` z?Sh;K*n?cQ;IBy|sV5MeU}Zfn(eYdB%vvpSAK(6`+z87t70|bIC2Hi3;Gnt(BwJ0& z!skvjaqfTgckF#}cQZ&3&3`a4|LwU;yQjTpLB&RbvfEBT9g=^3I4cVXr-&Q_0{}VG zA=}1O!SQ42eQEPUD702`iwBIsNXwv?6rr_1*r?gyd>{{El)_K}l>>A_))HU(0q)=I zS_-WE_U`6Tm^@K(mKIJHwI?ai$36#7A;X3PGzJ+6Vakcsy91UNA;5_UFL|_v_ zu)@D=17T=ESg?X@H%|kpTEJ8@!73U*DVZNhNANOp><7_EL6)fl@0C`D(+1-vh_nlt zZsavR?FkvjN1rSXkIfwzs$dhwFx>&z`yFH%&@tIv3+OFJdJwwP_ z9EMB(2YHJ?1A+AJ|9+=0S;C#Qhd0*Rc)x!`$T=J?WY+iLDXG6h&NIFbmBaR!PQILZ z0#PS#o#d8$V}9O!Ai~LgzrdH?D_`}Ju%fvq=FD&Ek$=7@_4ui=>lsU;Rao!c49a8W3KHX z#2HrgNLQ$xIYQ<)A>VO-GZls}AeI9>TR4u1N!3xG-}8~`c`}Q_f4uTpoS~Ox9Sp&u z9GHK-2pj|y2HhG#+t|9#ny(Ndbn{eX)q!)3rd}%CJo3W%3HSR_ zf7Mhv1eiA*@cq2s)6oNr0JQoaDX71loto(2mc$aWNnd(ep_Fi^a>x@*cLNBD(Ih)h zkxVPTqm?#!0LKu5 zv)lF{ALNV^hC|6gb(kAnlAVK9F-2%Tmux_U%|KG4k^pR?A*0STwT)o0&jA4p<~^9f z^$DijJJ;#O91w{!IpD8@sEoj56O~O=Bh-c;!+;0iiA?N~*k!bFW?k^%^f0gKZt0%b z$0R<(i7Jy+hgwbTb$@@NKl0E39%9H30AaU?fBJZV;LToRP6((K$Y*vi`#f(@r>i5? z5Pd?5^+c!huN#fSjlyLHBK04Q1#SEx1^?}ao=yN-+7gG}hbqO}2pBmp{k>JW#jl_poqp z;W!Ug%i{O>Wl<*rlL&}Ml0qbAfej4@YTSpi=(eIrE+%R%N^r=4eDFwf!Hyv4AJA-J*%D;+ zUI82J4+e!nB9D-L^Wz1Tgtz2-VilUR06X##KogdoD3;*@qmU;0cYXriBIOL794|^` zwF?qAuWfgjdueBWfs-J4>QWmKoTUS6N)%&oqU#fcxbN5{4In_G@C(ZG!?E`SZk78> zp0%o&Y=bHiH6tz41&*c^`=koG4wOi2$Kv;dq{4X34(=TvqU^Hd7C<0+%wdg`jzZFb z^D+Pl!mbo;kOqT8w(o_&zY=19fK){(_Y*i(X%1^t0vqPQr79spaT2Rx{lD0I^Kh)& zc5V1l86q+dnME0rAu1w5B|}A~WR?<9M4^%jnTs+eMUf#wh(sk*ROSYnOc|miN+Iul zUg}wEJ>UA)_I&Gmx9xrZSZ!VjDfL-{(#7(#9=R^yd9th$184n65c6O-?yM+j$jgO`mrzXh`>fN z0GW#lC--ntsrvlKw2<0-{LC26U<2_>!UHitW?puC=Vdl6Tb@S%YgoENKpTMJODBdj`++72aS;{HdzV;ol7^ZbL;d*35*~X zIZ}23IT4E-5%`HAkDwG%JzAt6|5F;6!kLs0Hwth#;j_M@Qg`Y^g1t=K3&ds*dhrSb z!P`eQr(j@OO)7J;g)y@D1{A7oF(}iO$etzSz7zZhGRhw@Ot7@-=4efkMp}APPllwY+vwApe*-Pl-iw07TXTWvri)s31 zu9Ywn4iIN!j1|H{Y`0wudGxRMUNlrBB`20IN4bfIcm*$tL)R0Qe*8$B8o=PJkUv!U zlgb^fKJxb8)8gxYlPC%>@A!;+J}^JD%*B}nE+Zp9xrV#&Pk3YY7k@ks2M6g@!K`u# zbdda193`J4W-ol~c-?UlxQU6M7=d8oJeB_O0mv?5C?Wzo;r3&DZ86r5&`O?Qe)3xG z&2McWA-6+h6KIehCtVWlZ6~g-gMI=7tqdOhGnZ$?RQ4ae?Rt^ZP~sGj(hc#H{_2s$ zCFeQt9?CWJ(n{lRKj3*Mvp|UF9h~z1hXFKFGSi3Rjz3d=*$0+j2P$ne`~<{i|7DcT z5Hf!NKEm#{fpFAu(PXD2()#7i z&F0mC+cv(XDEId>f1yl28A8L&_9LrN>r(Y+P5J&j76x6$9$G=C+F*}|I1!Bd|FUM))lji`IpW_ydBp5|w zQ|A7O7ANB+|BkpLjv6QycRP#c1Mp5u_@|7J2HJ3%91P?Ci3DHOwhgV$Z;v|xOgKuCTY*^1x1?=|y z{l`~6PR4F$6_73;0AkrSb%YVwc=8uf=*U3=g>c7YnusytDiHF&VLePz9sn+r#-9w= zmwCht8OQenp07rnvIOy;q{9C?;q)rHKYE&nF zrk^e6{Y@^Z09{@tEij}Hi&!T}PZq}tgL=%-js=AN2$k0&dkzh9R4W3k7+K$lK?SQ3 zPlkE78Ll(LVku(YiyE*3@fY=%T)j%LDjE4s^?qZyM}U_TF{_z07LtWTow$9pv+*|v zrelkeHvRwC5#b#(QuDEXr(W3iMO1MM<;xG1(V{j_h*=+@v@{{(WP3jc3hRgK=NbA! zufQ7B2VUd}@vk9aknh?7;?CSTnUKzqW37m(gqX?DF7Om7DUk})ATWP}Y2yY;n?@^h zU~|Ihl5xB*;2$BL4Bw&j_;ECqZ6MuE#5pWIumL#k8|Y?Im07zRk_}7Nj^uTj?~msf zKR}$?0q0;~0-zN;Yzto&kiJWHaMLG*AxK_X_3>T*My|JlOAm;d_}y`lI)^-Wg_Sv4 z9vyCZ&r4i{V0X;D#x2xs+n%OPheEh^0KtDso?ZC71 zU)*^XKT(4GV+cm_h1dKSHM#Ib|4)0~Q~MrpkqwXQpK9g95&SRTG18xun*@zwq;if$ z+kEvv(a*FNJo_6@GiQ7T*s7PPD`+xvjRGTzfl0DRiB_N2liikc{f{pfmSlET($oRW zXIf$ka}Z*46WT<-ruY7(UhJNMq&Q5-4EMisM|(2Z{&P!(Z3G|xdF`AmOg~Rq@8LkS zM@%-F|0EW)a6JE^%Cqvi=T|AflnmgF0l8x`FgE{v5B0$O|C#7Mxh4&^c6!)#aQ z`5&<3@;&qC6X}Kg{cQj9H#AZF&uaYNum<&(!mQ6>^e1HIJxHy{HcqgHP+dzz^$V)Q zH7}L~ZVSMVGz1|4nfxNE*8QONo{&azq|Ks+PejHAm(}k-3zElX*WLN@jUsd7s3FhJ zD@3^?z?0apP+w3fZD0&N(wvX{LJC@7k#F%)-yY0OKOl5Is_H`IAt&-Gq#29kG?Jj{ zSOgN#P$-H1+RX7^Btf5lHSz1C!Xf;cInYj|vv#eYt}pR25OA-)pQRtDSQ2c~Ij zl(N1s8ozjTfBsc&&A#Xi7=!bx9-itp%!oLMf~m1c7~>PR#XzIVeW5mQ&%ko*$@Pfy zLT#J@`hjqug@?PVq^HU7_cJ|p2!)Xwo-sf3|QWl zeWLDz12OB9-p#$;RYBz8v#kFRfYQRTY?c5l@=%B~79MaS_g;#fUDsHFlED*Rd&~}7 z%O`L5(4}L(hfDLoZr6oX5Kvca13A2hG-LyhS>*Z>gT45ZP4t3AZ#ELn6@!Zi6?k=~ zZ)I6en2d4x_rkeF|KHlF|Kn?-REN3!V2&QuPz(9VqF?1VoOuK2k={lIz~iy{C4PsU zMvYXXq!pL=9f&6a<>3W5{_v6Vh-T{vZpT#R=)c-|opgM1*tGvA8lf@(Jb)7$^v>@N zQ!Uzkm*61fyZ#u)Nu$>~k$08adEYQ|a(ndde{vcS*0mVQZ(PC4&SR`*cY&qV$OS1- zHi&{re8r6uD!2pY+jZetNo*{AqTBv4LivJJ0*1e01^+W?Q;FZO^{)vE0QVU&K@s^E zpySCO6V&kiKP=S?W7koRz{q9FdNHW`%NvLyC9!SUI-8$6oGU=K5x7noNUh*5Anm+n z!@pqZBv<)7#ejk}q(c>OqH&avCCOlyRk{N?V#J9J&+~mwsTi3_@Fpo{?_2MamR_3; z68hShckZ%zm+wzCgAH94rZ)CsK??qPK}=q@oqb|Cv&bGB@SjXiwSpinQMKO2t#itx zQN$nHDbM&MS(O-Qy+cgX*hl|qv#)6q$NX8tNjprye+E2l$C}6~2bp*%c3_@`yylf# zOw$dz@hWK1eNe*)j}>=77(81{U%)~Zi`1_f&=N^pBUp7Q%nw$$TPnw@3r=lM^pdQ6 z;(3<|xYwCn!Gs$|oKU;IhK02izmL!*;QvTd0O@YY_56pnUc6T1);|?&V|GSy)yr?{TVXo=j$#a`M zyzMEc;@@5NGBR4UiSrDt<3pk2kw$5u1?hD$`wZG&#k^9^$W*A+xgJ_^tS0zcfxY=E zNy*moiNZHumOi+WKhnBs=ezgcMlOzADLeL@J6`?xzR&--0G8a1XI!)3-)@=3op45i zRsZvae_T8IziU0xOxtM4t$Q7uf;6R?Oc@t82`f5lo3)8{Sn_nf%o$>Jkw~O?c`{zR zy!@Q0!P=)6J%6qyPbY2Je(Al=jE1~gbn5ljo*c?@QD~A%2^MI67OgF(&oUiiz`OT} ztE#wIw!6$4cYO3YC|+IFuJC%<=p zgGqfeyQTADH>ETC49|VDvP;>uMeE^OX(oJ6%b>q%c#w`912+TiG|qgUES>B0og zB&WAkRSVXxUEAE&b~Y}Kd#_@gSLO)r7QJxgO60_Z7jCT->9C?V?|SjLqh#`fhr;6A z=PL_!amTxkPDfsPbJgL&P1l?6Hi4UY{PgJ(KR>_b_PviTG_2In5PtpoweI%qTfDrK z9ESNtGAe0S-**VjqrcwG8oj$sl-)?8K$H9)Z>{d0YoVcwaYrS&oKRn1pPM#IudlC9 z1UIvS;N0TuESXpA>$g@hJdnJs$rVN(`q7b*51&6r!8!7%zCJTU`#MiG;KbqS?{07I z`*PsA+V$_cgshitDbUrONOn9;ZYH+znPlXm&%Jd^9>fAKFK@e-#l)GGj*cK)fQr}l zz`ctC8;&k{*=eDTK^Sa0X-%s@|%3zN?cCtxXQC{>pWsn0~mv~l4M zrZqsAK_zR$l_o|8KYh9m>4_3>c;!b`ohBLTHX01d&`q+ovhvv?t9~ReFOL#nvEk%i zk0c|q&$Fc$IRBhWH0E5H&Ze#(W_6BMD{X-1lmxEu#3yHO;g4dtcKjb+z;zVr`{ zGip&^Z&{ey)~#F5wS^l;Uhq7axTH*V)f!gy^HXy1)DP@8wM>s?ur+}xIVDFbr#?x z67`B1c(2Z?J;8O2@%lf!M@L(RUlmAeY&R@x&*k?_ zOS>`6(p#JQXzFXf(7^5rq0I^}Mg1!0PmT$_RfQiHw2;`({Cs@&piPx|=i{7Fe|n0> z`IQ4r`ag;zBJJ0?m@wl$Hl*pH|2x(L}u$rw>5WMRHEM2N9Ri`nP%XT zu^GkAKc~#3SDl?3FZS~C@!^mTCJKa;&mW?D+kbYyr+_R6hjU_r+C@i%Q04hIk|i47 zNRAWkD^|z3a5H1f6ZRG|0Gry++ff0J}hr+Y;3o8@1h+$b`SvN{3<_H?-nsLC2pJmR;HuG zuIF-abkvW$t+Z*=rqUw=EVzG}jh{apSIFMDc~e_opQ*OC7Vo1PPNN8TA+&96AH4EN zQS&{6)6tE3G35ODbJ5XkYHDg^w8F#Jugji(t*)!<#i;T}P*0;{VixS$wF^q@G6)mGCSUegFJ69Cs-%U$KG?!*rVwhY}O{R>{hC_w+2l)!{zJkH=3XP#4TaCi?36 zgh+n22d+GYRaGKwt!GZ(S#vO$B?t~xhAWD+U(uR4eYwbu=?8|b)lR*9W;j&Y_NFX< zwO6|Bo99}Vh8HV=n z%GTVXV+qViwJ+%HOTQs9{^NbL*My{}qr){PhXohcy4lTrE|>Q_CU-BscKnmY`mGPj#q@H(t@*XgVYuaAn**|cwk(!|6yCwm%eetzB( zXPWDAvi9$?IFy)|GzjpssXwp~AxTeKOh$(3{fTeP5R)?6{c&$Q3I8DM&&HWeo`{4d z#v8Y4_ZcXOS+XiBd@FaPRO69+1*CrXgtSG41mEOe3;p4F|554+#s4dB*|I3J1ZU8L#G_E|c#TI`>l( zY$;Zy*V0!syOa(s#7UM>;2tt?CckOK2VIeUpX`VYf|3J zJ2%cT{^jg}Hjy3ryAJU9`qn>d&~o=uGG4HdZL&`I@a&{HZ)s`f$Ym~}s6CFgBI{Oc zacR`o)$B>t)>ZZ6G_0I0F(|lUT^Vs<>BgO5E7z^%{<`6)$k%33rRkxY#&x=Co{N60 zFf_*d*Gx%a@%2;n6jI0gmtE@~_*GvnR%)%7U*&;zFAB*&#m~SU-7pqdwMs;ULPJA? zcaEmaOwLSB+$vsx3!Gh4+`YphBWoZIL|?snyFKOd<#5<*B)RCBE;ja~Swr~3rAsC~ zQ@El&LFGtBr-KA#@BaP0UB)qA@bp$$-z$1GkU|!0aPyem_k*!LcJ`s?SVX@zr@E*d z9W=Mi!b|lH4KYmro+dBV5E@oB`nA_+XP*aGU*Cu4r5<`UkE(S1H-5(L&~E|*+3TA0 z^~iU0aLdqK6IsJ?Y{m9f9ua5WSg9o_TqiFzyxL|HfhQ*Qa9ewf%q%;M^;uAxxu~a@#FMS%#75;73#Nga~qZ3p?D35UcS`< z^8Bl^BRAXK4;%;u%_KeQ>f&OVXOWwm+x_`73u08l<;w&|J#0Jk_QBlma1@(q*RD}0 ztb&4(4>SJ?OD_4pw*tRw5Vc{X$tWzT|a_jxZ zG!GbPA3LTr)yFYtVJX)Cluu`7fGwvFwaQSVPI)HXM-S5l6SzZ7_D`p4J4p3OIsRZTMn&V|@4nNLJq zo-DJdx-5fRmJGNylr94p1e5YFuo84-~WZv*e| zQ_SZ|2Y@)>-@at7nx?u`2^`C+5QTpwn%zcl37CB_Z`@ZP& zRDn$}2zoq}KzdzWUCpCMj~3M23DjyE=o=X5h1{;!g|P`DT=ZOiKR#PKoR&w+KqPXVVuOSG=T-rhlH&n^@Y5NIwsu#0!uGPAB< z7pV)}A(dON$mqw+uZQ=y{Ma$x(lXIIx;7?U2jBbY4^>tsDK~t(s+Kk0YZiE+@oFBI z&j_#Lmv@`p53cIly<0Qd;j)(d!CMn^w20cJ-uefQR|u)!=(%=g^3&~j3z4VDfHu)m zy18`BF=1ze>ZGzXoj>#~Vpdn@+`E4v1p(f}v6_L`*GKbx3!7h5lswixu0TO4`@rlr zvFs~TpXf;0FS_|=_Nvd)8}{t!a*<-?mSq?o9LyZqHQ{nM=|W7*8BmznoImwqVq+y{ zPgz7>*!4mLdK$T1j8Dk)vVHkz)2P4SzABnUWK@B zc1zOSgVpC=W!k;D+Y#Bz=qWa*X1y^Gk`pa&wXrhhCSftOw6vw(pSU4N22D-*EXm5q zBGp1jNcLHS6*(=vu9q*%yp{8luySKtd(6Eo>?|dq%6FS-zocoxr&s5Nm>++9#&*Yk zp)pDVSv_t}k=X{WH|3tRWd#NK!s_=Lq6~G4d6)vph6zGs2dLG$C(OZd{pf?9&QAoD zETW~ADpo$wx(My$ny^D$`!ToXbiABxUJ;$WJ!;J6_I8tdVLRU*NlHnv>A95sakKwk zO2FWtb8Jarr41s=*u=yl3ee_qX=z3*KLH>3U^iyh9()ym)1a-bechlJECN8b(x~6$ zg~<=6y)*{CIu=hmZQZgl=Yq=>q&#s!m9+MK&A0djHvZ=Q70Ivt_#slx@7W)EuB_Y{ z0D6gWajm9-!zrI;R_+L^@MEPE$;z5`wny%I^U2~;*4=dm6%|*%hlXzYetl~_3zK$L zs6C=nIMM1fDjRZ3{QMnfB?U)QirX z5PaUg{(h8p!dGhXnx+MIAQ*F#B#X&JG${K^KlJqlj*T5o z_a08+*&ObSl>X}OAz^Tu5?MYLovJ4pk;J9V=iuC7#^2=Et*e7Ay1@Q1op|Q%HP?y4 z-i^6=t+U{8FaVGbR+`|TAlfr$&laG2)3_G~QT`Ghy)Xi8QI^mhy8o4K$&w|kfC?li zuHWMBzB+iRa1=T?9yK;HLBQCAD0wY6H?ps9??~6(9qH#&l9YE*&x(soe*%co>Au4b1%JK=GZSuU!0WX~)uGJ=g<1`$wO#(QaOP!kXp_7C~>rIN{T`EqfgjS&f_eoFz$jU2QZZl%o_?Rs2v@;HF7 zW?knpMZqgKn>sQKvzs2z-M%EA#>4;Tbi${+2v$((C?R#};`AFg>^4Q~?B20sDLAYq z)%%P`WIX9)C^IJFqLV*<$hBGAyLZpIt0cE|B6Z!NchxXNU;2J0KVK8IWa6UTm6^)`XyQ4;P;~2Gy!lt zt4KXTS%1iXTG56pE9)lx8e7}+A?-Hc!d#>4m%g`^6l&R=I8o4Ru~aZDftmk;pwQ}K zr#GdiyS_ZI?F&fKTpoAr?_Yi{yu8)u?1}W-obk-)!jru|@Tl22IhppX|M>A^g8K2J1+*K$$6)7{ zRVy?+uC1lJdh+}3*vPHq^iJGTN>o@_OXB!jz|ZziExmkOOnm89ua^JPS=>;2BzY}! zr$nE~q{c6XeH%AQM5SUqUiLSpy1C{0(}}?lA^IjM9JEu z=T~iec>7ab51n7}-hYo%FPqV5?R;8c`}OOWMUgWby!XpV0>Sf_S5neJm8jsmfetrhv=a4Sh;SvjsjASk4bVFBxek z$VUt-DOqb^X!sR zGyG6?{jlDNamm{^*HUlN8#+4j;+r$Vxl@OdRRRpm+U(;lH@2o9=G73w+BP@8yxnf$ z$HFr9s$gNsp$b~-9h_<2AlT!?=6$1*D&TSi`7Y~8@hXvPw-ASPz%}k_;aG49@{0GjVs;M0bwU~dHB}bv3Z&j*p~Vy zO*N1TbpQUh*uU}xoj_m=v^yLc_(RghXeYGhm`6KUzxp9e*%gp8FTYH`aG4+3bxl#Ovah-%%cJBQ9_S;Sv zJq)C&!vOWz&Tr%6V|`CJ-c+9C$55$<7;F=YKDYF$I^U@=QM8vt5rpIhTFqx>Nuzy00x)9mrTeo3g3Tp6*QelbD{yUsnf$!8;sODREaRLab`@Lv~iMh5?1^(K(3S>;V1OtRy|5m zcD;7}`mVfQ!h-=M>A6KkMft$mZohy3IT#6h;H{30qNM<1ZkcB%WWF*`)ricz(5^C$ zA@keP#T}pC@)O{Oqll;v+s8I+PD1iCE>%`*?%+U{6+P|>i_ zq^_!`&QPWM?61Q%rL2m%H#PR~-zyaX#PP&K-s)r?KN;(gJdOVdLZ?|sW5`bdU+HLBm zgzj#MG4y%_)2C0f zYDpG*kGwJ(@`v@;?%=@)5Ykzra0vCUWn| zowu#Z&d8YO*|MLxGe1!qU+DBaa>V2N$nN9~$5@bbg<#yZ9P*jp0iK>= zPi@k^^ogTVUJ}Mx;pylY?LWFKqT+t`swfn0)nDG_w(nDRW-WU5ENAb`NWOgF_LPV> zuPm};HCEDt84yzup+r{Buzz=~y`;L`^00hW<($yCmanai)&lmWf?BOD*PZX#(X5)D z$cjk_9jvKSrtUqLG8AOeq@dJ9PK@o`tH_7OLlcJwg{Z$b-w|6pMRFjfkBhEb$CH|x zI)xk>TQWU0^)QX`#m6ahjG3THMkhNH_L8i3a(`pI40uHjEjtW!Ab@J~V?q{-rAdpm zb(r3>EwW>vymonEzvF>_gwFs8*RSW4ot+>0PYg=)YkG%C$XxBYP;iJ9fzZ=s1b_zmL%m zL7mXfobF&_X}RqARgLp#7~tFx zb@3ucCAcxE)ql7!Of+uYxKwy!H6k)p6Jv=2>N0fDjz=tk6mhBa{DsEo+>d&il;(n> zz%8lwYux*Et4CkaM8~fz-kfU_nc0z}A9Xz;XHZ(q9=LaGLeY>e4sb|hcuG(L0Y zLj0vgqtXfeK_NB==l6zb0pP#=*Pa@Q&lfLT2!NMw&uB?aIL~8*B2_J!i)fXxM+2nK zc;B{yX--bg55Q^<;o*}%^3eb}*Zfy0`T~rkJfRlt>34Z>;lc%WCNlecr@J1bM5vu+Y#)Xhg=b0r@Y1X4i9QSm;F; z=w=wbZn!NKQhSK9>4LRSMHFLUak1%vHA+f>1qB7hdyJ#+9B)*Nx>o`Vfaeyf(?G{G z*QP0D>`fVm88{78*;&>Z>9iM1+`8|1N}za*{pHNH+a@(F=DhIMqtV{YqjgcaF$GL* z-iIZ3?QvId%`%^aXi1p|McI=eWK$rBJ?p{kk5q60QW=lLzgR{Gkc(@2zcZPXv>7VgG(Y zp)ONZtq+aR@dwj+Eo|R{P|G`8RH9ZG+PzEHkBuNBO@7^Wk<%<_D3Yix?~z{@;Q+4UA#&)g9h_c&JMC)b^C zY|n5!uzPrY@UtXSKLIhZfcaR^;Rn;)&TYmE7c98++4jB<+tAREo*Ty2n-I$DESMgd zw|O=dAe&oq*srIhUBIQp?cBn`w+lC>25;$qV1}j!EvKY*l#f{>1H;3XJ?nv2;_}>Y zJ)IYvO=id0Il|w@c}%YrJ2hvW^(<}cQfD5O*K(6HzCcwOUX4)^tG<>Ro%)^SW!u%t zZnlsf<#@F^~Uxq*R&r4|BB?01E=Yq#K51U-QP&~v!JfDZF47XV-{ zqN8&FSk#+VLhFm3!wTw=ZdiPI1CNT4QT|99{T3miX2mty{v_I5(VDpDEm{Bc()SK| zdCjw5du2{O9SLCo2FY&UH~nHyv*Z9fvs0|^^xDr6{Pg85*HaHThg3g4$A$?-decyQ z7hFfrvcSrp)r$kSU#3H2F@@p*?GOK*e)8#wld&&6#l@sq=DCrk(-=LEwoU$4)w7t} zD7|?TL0_aDu;1nsD+QE#B)W5?4L|w9j2kP%%yj4O-E+5ZA1XU=|B9rU^g>34c3)alx5~lSxspV-gfd5T@<}&-A`oCFO2RUv8EwYS-Gl88J)* z)(GV#OYGJhv`DX$jS8;Zh>*Ke>Lb^S!A3ch(lb$lA@2*cM~|YST;~6McTlf*+P$Fp zjj2w*YSk)?htYWW=uv9*ECex#W=k-|C0nabVa*znlI(ZaD$4+yhG4vuStjr9c03q8 zt91pZ=6(;N^*gH8j10ZQvvFjAcyAudZ%nwI8S&@@(_luz)Pc8{dh7L;6v7sl}i4m>*W*$T>!tgcvE z>wI~bHV_E80sAfpY#tGi?d|P#_4UtIR&IoGpnG&Q66!})dwY>_orkw~4bc8!%*P^F zj(p>`ZJR*bd3k%|KTZ=SpstP{cG)v=W*sNe;asXlatVQY)q`n&(4tBHO}=tP;qV7d zxQSDd@AvF+>~DyDb>DXlC{~ytNI_RxR`v)(n^V$necD8QP)$*@2@1{Lhc$^*4}|Zp z$EdT329VGm4D!oZ*SVQy;Fqc1y%bSQbjpqPgr%xw%m zt7EcyKCh~4!wV8Sd>e!$dhSx;#qa;@>gu}ZWACS-p;TyLgL`QA55}ECwKOv`^I>4X zJj(<8raV*(i1Qg?V&LKd2xE$tF3#GiLCgj1g-+>FZo{M2AA2{P7^4wzv3o*esI9Fu zGqDT*-=-IN9Smh{TSn^Fq64NJylfr2lN~P27vJ;$%w(Cjz5L^A|H>CyvWF<-2eNn3 zETZuBdHNcj)7lLGi@Dzt$Tre^H&wYf(K@v^EqEb95u8)OJvr$;;qt{mPG{)i-`tKylZ>*@%3BsVOg z=Tb#zc8XY2^b^{mai7JaE#E@^{B%(7P+|liy0XCHj?GAw05*%2zA@SHHQgzvx~*LL z>I0lVs+y35Ku^gJeSYTFf3SAFyg>iqKx=Y*y!qpnoz_Q?rYV}UTN9fMT{3B7W94m*n{}0EXdL!s>Lu=msFdY z+!IG#=(sdZULm2Vr%#_QiwYi&f%>hpb*r|SNb?{K;@w|oTg{4v?lT%5m}M{D)6?U! zUL>0-kafM!mQ{jUrlHe-S4V$A6b57VNT|CRb&iw;YP;*Pm>4E|U`1h`**IYfF zbppNL#HotfxUYZKoIy$NnfCeF=v6H26_ly?(|F*+5m*%dLOTg1cw@%`xDPS~;1MOU zMA}gBTF$&FEOpaEnRAjHmY&;X2+5<K&06}-kJhi_2@MTZ7FsDoH%-)Xp1NozCegEtm98z-ynJALf&YdW<2KQD zJw=+qK2FVbTUqYhu;Qkr9)ns2K2|!FDv>&y=!N{w4_Orlwdo?0Sn}c=Qg*3DE}(fk zvS85ZiG5sh;%V`ESJep1^Shj=curj-a7oA}P@G1e6q-cS?nSIT zPafkx{`@Ub|En*^oj5HI7Zuue0>Zctxc>2i_={1^{bt&#k&%P& zD{!G1O|Z4CjU*i)9SWu>osN&^5mNKn`LTeVo&5t+95^*T5ZJs*&a4rUkzet0shClr ztEXo?v}@L~}1?}v{btsU*UTYw&>WHn!QIKVdV+qVM4hGY>Q&HJ^f;Eluv@<+ObX^+U$XU}Rd z>XAZejF*F1FuG_IzwjJ1o&^gQunGuR4O_Idw7e?ySnlQJmG4l)=zQS7wk~i6i)3VE zh`Tf-Jbc&i4`3Jy1ru}`;5p3#g=SlaQC+OdmW81J&JmDXzI-8?t)z>6N>BXsd{chx zcKZ~S*O_05kPuFrR|~pEr90bx>GwF9XF6<0;akya$^_QGs3q0M`rJ zNP#19Gr9+1$@GBXC`*0SswJYLqRlT~ZiOMU(R&70JlUazg;&905o)N;1pk?@TQQGC z2ukB>W&iol{^y@TJSCbw0EZM!P+9Hn=0&#UM{_M}jVCakmbR$47{85Zkf;@^Q79_- zO*rFyNV?g=N`+g9*1UTCx*Lp$^ndbOr!a(r5^(nH+0N0h?xL!6d!& znCIII;;0s^FO%pZ3bnMJ`r2y`BoFF36;oc)Yj=QvL<1G zwPVRg6&>YsW!ql|1_f=}y7i`It-n1}pd4|K;$jtEoP!z|it-^F5i?qS^m1ohfgCHO zzv3dkn?5ONM$3|$dBjccJ)XSyk1y}(DKO4-92PgR=T;V^fJkO~>Kzi&aIWn6ktkml zwhm8KRn=D7{idTqvaHMa_-H7w>vgy{)z?!fy!`y;jTWZfCzzO+$b%(pA%)nCYVv|C zxS@OAI&2F>5@u}m==lR}>Kaan;e#N8=I|V4t^!51fCBRU*3t5`3V6pm6wOQpPhqPO z8N1WC=g;AU5(MA`$Y~A1AME-4LBzN$1Gq5oHS*54uHs)EC zr4^n;%%B9oQ53IghH3lPQd3XEU{_sJBjwikBWG8&D+xj)*JzT`P8^M1bli z6(x#(z1fX(Ge-_wtqQ8a!Lh?c66ek$DRl2&b~+*yOG2<%U2g$2u0TFYm0Kjw9h5h6@kl-4NMRM8A+>bm#@4PYxqdm=Wg!;Bi2Mba->v;Ek;rx#`JOSSZEZ_h3$K+i8e4EAJ zFa7twFtH@)+&Mt0M=dP{8QP>qK+0M3Q4zmc4zvTpfOHlpWoYaP=z%soH*Qtm*!Z|L zU0>nE)M4b#$Mdpsrul6lTx}Ww^N^T^#se)!+tR}{ASY`u;$$%nYW-eC{NF9>@)A7< z4ZvXB-asKA-_`Z;t5>}6nm)oTl?&i~OV0Pq8ck%}ym=0A%@OU{gny5%rV95aLQkC< z27I47p*b~ot+3It<5Od)oeHaMD#y7xRed0Y|GbZylz%1eee&)C4R21s@Ew-G|=ZT_fLt8N@bj6hb#Fv2~f0e1X7> zccpD(BLqh<%)U_o86SWrYT*fB=hxj_fq1wR21CQcJ!ss2h|4MdilL7lJwi7^9SB9VJOtvg6K8tAP&wLGc-J1J zTH!%QKYIN5x}F2ek|meI&H^xKoj9RJUMN3bei1$W<42EZaNZu_=EiVA0*x3L8Drd$ z`H9$w9_F1uVBmw6OG#zv^-$Xm)3Sf4eN-eamj-Rd_aF{PY^U(5c#112T{f@0}ft#R8&-aMJT0EW@cw) zeuw~wR#87dwgkMBr%AX$)07TMY z^m>I!l|_fJazy_n;2GU^$QA@xwntq?93%}XnAM;K>}v1efG4pK@wNv;rnLbpO*LpK zly%^IHTUh?R~%u>26>!90S+wg!meZdFq9Vuw;(V7!v{OuRU<&!d|X_lk*vB02SXqdcH?R40B9ZtdX+vU zgiZ~LnsrWE$=LY#g0Tq|2&G1^^Fm&Cb(OzyL@1-6wxdwY&0p&2Y$J~yIIL04l>bO# z&)&Uzb!j!dT~iUfL=+X-AtO0%Ep{Y3i*gv73P=DKrcn|W!dxj%4Ye+zxZs(UXzO@K zq{>Ds$p@Sk?D$~76K*-12-uW4yH25lj~5rZqVb5d9inHf2My_GSr=efp#(-<(b(GB zBE9xP%npf)s)Yhbp^(DKc-OAjTG!@lG#D=yc;klnf!t7JQ_7Z-C8m7xoM zkU*R`ipbH=Z+7P;oH%h}5rs`q&^pgQh*@C$PRh=E6^JrD2(?^7ck=NsCW-@S{8%9w z35m^^#j!mhAtAGRWME(=9$Ie6Riv;t&wvMTE0is$Bs0}5B(0{6+pXTU&XdoZ^Y$zZ zAnoq&3!TC#A_os13=@QSwx`u`2NBr0zQo43BVu#l1DO6U#d#Ph%1-@Age)#jC?SB8 zJy;2pblOLcuBGUJ1vw3Gly-!v3J+cW5#Hy!-C7EdI_A7I zP+N9)cY|Smgz|%Q-hy>i&pAhh{U+%NjqUHOGBS}tl2yGK1AYx{rqR89H7aU*K`$Kj zdULle?p9=C&q|K+ptFvgQ7nMPJufLV&DpuRS%0q6)X|BpBP@rJY=Q5$eYkCELw$X{ zt!<(E|q#5QYG#*2dwDM2$%jDthU+@9*qiik6J? zh_K5=MQK2r<&3>D9RgM5{Jv5h{>*$hbVxNrnkoX2A-fT!FJ-Jnq`KN3rEqk_;=nTT zpfhI{kgD4UBQE5fn?(fFd{|Lp-~Rj`KJS(*7!AUt#h`@V4kFT)pNHBJaC28C_sy|i zBQOlQ{CIuOTLup1*iJ02yho4D!8fIBm6esH{W#uPKuhpEP=g)5pMYmHuZac* zur44r7(D}&s)4s}-)_it9=A${^@cRGQ38POdXQ0_hRyl#(M|ckftO1^iFp`duojVl zyHgm^a|6bIOhsS;Y@(&yvM5+^URY!Gw*qA)C3(+2T|{3R&oY5s83tK}bP#x6}txla>wxzI9p9_MC(kKc-0<>ki{9Y`w~s0`Ci zJ7l%1Ujh4Z-g)!Vn_fan21PQxBU zj7oQZ4o%y zc1~Mx*1QME)!wtG+^T@c4(JIcChg}b>fs3b@i`ixFome1b2crkbQB6vL(6ATE;T6% z5YxW8x*)Vc7mi^(XyiSk^phC<)4IDX=unuTI8cOokCAL-t!A}gEwv{&lI zu;kJjOG%u0oJOf*r>K!Bsmf~0Zns9K_DPxmm?f?O@YQ^NnQyNR55dBzFM&_eI7fjm zZtt84QybjP6CD~dC^et4Xn+zZnqFQ-4JBacu#9rb>`ibMUPw&LDkOtl5T|nD9C6K< zfpiQ?Z_`GF27Qn3w;nur&^d~Z*;YRpgbSG;Q}{cD0*DML)J$HF`mXk$FME2*#c}m6 z3kdBP*g;2;;AZ0@O&FAi2qPvXAQv5x+$n_9i`x)~h--DvYq926*PZO3q8&)HawhlA zWVdoqVz||Flgs(=@W%k5;5fe;^eUj&3PExF+cyRO8K3l=9=t>_W|`=QzleVs#wC-# z@7Xe7dYR(Je$Nqj2T)JbQ!tg86+6kSa{!Owu$Nch_3MshFMj2$s#zw1JH~2p&Bm%L zTuUgZcrV7d%E`%LXl>!sUT#iKJA6)b_XT++SAF82SJRI^b#ghjdeK$g;fhOV(=yU^o)!iL>V1mg8TRHH$ONv=UGGgOd&&xz_*6Ig{fV)ntMgX z(i*4|Qf3%wc8-gSt9eI;(H?w{159W+T|>jo(2kpL?Op@ zpbUUxvtxVzMG)%5Du|lN<>rXguVtjmoPvSFi_*W3|JMDA;*8KhtXdyH44+2DjceS3 zFgWuA2C~-THzA@&qPv4?7^eq6^1q6WXr{@PYVAs0G|MPOPbC00uG3aY3Si*$6Mz#kAop;GfY0 z|C4xDw9A$)BUbJE6%{fwXrRvQ0kJ_L(%EJ+Gj5PLVQ+0(RgWV66FCP{4}#ElLC!c% zBE((|$SZVAOyR=}WjZJg2n@mq(hayiTL|gru%~Ag7#wKdwjY4la+*uQG^|Gbh-xu_ zXd+lsTyo>;N~mtsdyg?1HZ|4g_qPPXtL527r=&!IUL5l8mZ3CifkVV2sX?(xXpr89 zSQ%snRWR@$hK53kNerBHP;9Jefg{4Kq{KHcilotlP>|p|>=7&O`{9U$gt#m^l33^~ zbc8SHPYTw2oxx*DP&Q$tK#z1|Gm_tpbtH{kq!0szB=KGD-n;9_S_FFIIROc=pw5u} z;fSr6Ey|%4lEiX`THOBqM+)=26PCn%>Yx8Uaq;&zAn^aU*$@f69XrllxgrSW2|Tr{ zYVtBRAt9DT1G=qS>ylW4bhqIYi(Y&J=BG1>r9Y88UQ|M0%q`ne5krhr=i`LG72Rgri27uLP22j=yf?bIOO}!`VngY zLhCt5^!OFvP^{KGdLfD-|8gJ=f^jU!Opf5-U=XzumA`+vqD@@w&U&B%yfG7ZIy$<; z9v+z0dm;dYK<4OWH0B}sJ7SiD^NU-2a4w+dh8a8r8RY@iA&@n#y4SYy1RrVGfsZK_ zL<>LwIfJ*h>?a;)LK=|&8)E04YsB;VKS&u_rMwH{o{R4TBmQ81|CP0AUqHz zicB6n3oSfN&wjhn28EkGVJe#^p~#WI0?5vI=+Ge$&4gP9#mT$=un35C+cwb`i@9c# z1=W1y0pQ5Xp;*zlY!^ECnjMd|22=qj@K04u&1Rtg*ildzToe;`SP}fbKQg;%*8myo&z(UqNG$e(i90aD83aoW>b?euJrh`3R5CGMMd*jB9Rdsb+ z$&mImlo47O7DHi0Ge!^2ne<6H+saiH`=RXWpaQPJkbH6c{l<6H`QdNgSQXd3t2|kb zwg_%YU~zG=GcIUB4+keD@XD1dsFo@_{ZSEv%r&R5;`zczs}1UWTip!qaLt}j$wrmp z%}%L?<-$evl3p3EVV&@&5kaY9I)bD6or_WpJaW}K01IGZ|5Pf)VY`vqM^3ps*xvM; zLa_nxoY`=pC_C`B@igjZInbRhs_ly8o{g-ZdTu6(O|Mb&Gg3$`nDbjRbTF}(P(3EM z&}hz#_r|O42LaV&ZZE2p4;Z0egW<}AFU4uDpi8^82O?Skl7WC01G!@Y*5?}Fy~Bw4 z#3xU>kMR|3v@wI;Bvo%|)Wxjeo#3e_9kN87CH*7G$pVPdW=^V@tSdvn9}t?hrzhGq z?}tLWBA6Qd*Gf19@{fJ7nE{h8T^If@Fgz|z+`qAq7%5v_ zsOEKLjrF6WQ0|9Ds)5vET}RGu{Doo_q=s}C?(62D(e%RUbJ}BU9E0X?Y!k3^mr81? ztBDbwMpHn;-zxnJtm?Fs7GNUJ8p~ZafieuZ=3aynf)N2>*!!`5TilNxEs@6*2~b*M zYuA2%O-rQe{@Uaoji*v!b;&(7u?YzlJsqbZ(xgyC(D!)NdS~p|eVa`dCcYZ`x#prE z)LQ0xu(jN&lVJ!A(ai;4Ckjpj+X4k%_ZsyQhi9SaHv|Ld2tg&ks3;Ht4R8e@Dv5oX ze$k>2m`ZmxH&+$`)Xvp49F+t7N%__OgbKxIlqER55W=p3zk?Vfsr-)AP&nGq66i72 zB=!3Bh4l3F5zCYrG16)EOGi0FWMj%Tm7a2Fys~gS9@<7mUG{>|DF<(p-?i*+~5+KGB%K*E2PO-QHacC3dvB0LR^GMlZq0_ zkc3hs5|JT7W+_7<6{*(eaNW=US$VczBRSdVThW59RzdA%^njn+0?jw&drxch&(d zXw{l^8FhvBgeAPlEwqID{$qc zZ{~?mZrnPoZ>vdhbz|0P=kwgwwhBK7m(vldwN=J&IIQ5_#EF+*&o8+SF%;=V?Y0AH zuFdL2@_wlftOKMaf?{nWhELQ3;yj=>yi$R#-K`>k>x~`yOMGd#kr-#mOZ&k5D2pc7 z^E>FR`E9q5Gox6K@AZQvT}Rr!b4xY@OFMTTsa50;MOJL=o$`#@oTd&`&mct2i0%q$r#8BN=gxQ-_=X^!bxfdI{V6(*Z$FW=aS4wgopc4Y(unrs)u&HAhmGdl zwIl;rRKscaI8o@lU_nnZuMY|W3PSH_*U}X%Y@IuI?pEwvTg^)#ilDJ&(R~hgvu~I_9qSH?O^I+i%jUeeC$^SjU8GY+ID% zmsvsi*^p*^=Dc~jj1m5I!s{6G{=mx8dHeEE(H%D;nrGzw$KJ|Y`L=~~eST!TB&`Ds+cNqo5AtlKMywpEJC{fXXqO3ZF!z>17#Sf+wQ zIV#i{;1R8wMU~FFyW9@l70Cy~!nBL}-@bi2FrVX5Ptk`}6V5KHemtDS-{XO5m@&PL z`35VBA1Z;ZbX*osIbI^HS619h_4W=2=MeeVyftQK8|OYxKBv%Tmiwe(V&@`xd>~Mr z_4F+ASNtlQr<+Du6Kid(rlwZMM8>i@4;eCT-c!iV41lC)_nsX)hN?dItLpO+?Wnrq z8WUPkiKkntu;S@2g$feI<`q&s9?S^dILLfT@1r`;k>x~{lykS#!Hw+m*MVH(ns1*uZmhon$ksY`5YgVsTRfzD%;xiYqSK_YDKI%lvWQBI}vZ9E@(Bz1brluRvf7l$#esRgu zhTJuiVdhU$=dyfkYRpX+&)Kf7{?x0whpXd~m?3n~9jogGV4geW20982fMT=mdm1iW zxUfj?;k7Vb9N{~9J?LZx3BxTkHy8Riq74-MP&yZiOjnYc__JW8n~SgoWhbNORNUh zJr`HcETheE#ol$F&-LMzTld;!x zuiw0>Gt{APShUWgMLMx)W)xv(Lq;GlJlo+0O-*d#3oWO$brevldmn}AJe%;m&0LVx z$LH=PE`HIplQJ3#6udr7Ht%v9FaR#$NSmixN6uUN4;ARu7bd<0M1w{KT(-Fh3;i_H6XE3f$W zZDT-ynn2-9`B65}cQ$}Owtq5k$dI-Q8Rz`8iaH^;R{zmt-Ni4`t7Ee$kB{U&inY@( zss$c-?tK1ebaZE?Kt3lC4K@>QEvkgLCyo)a#H7ypqH~k5({U%lwL?NC+kbf~RmAm! zy=N+s;-mwE#ZSuaqoYwX-Ot}&{|~CTdcXmjp&7vu3l7+5EavblLQ-n6B%JckxA&?F z^$r~z!%J#JZ)RgjMJECa#n9^WSHMxoJng9xog zqN#}7*SY8(O4fu{G>H!rpFV!9P#`9ZN`wfZ(^QzS5#~Ypd^&A~HdD-yO@DZivDFNK zo=kLF`kA9sSJ7_RY7Jtb+hW)`Wd(!9NtF$zHYK0o=y@F@4nexQcJKZgf?uUkqpQS1 z;p@3znh;$XI}=uYbw)VxhVX@m-i+U}4_oalGWxSq8Y*V&d2#h4GxED%qY}J#!NV0v z%$dhEXdM}cRc8ps57qH+^m|YoSAdh`cZ~AVEV=pY*)!)%{ogK|JO9-G^P)Gx4O2$i zTePs* zJlcAYpiCfrk{d2z->uo^Na_3|m2n*Di2N;sGQ24bb`J7g4W$Cmqfp=u)**e{ z^F-DE7@K{#e^gCYi!aQy-tSbW_T7w(g_GXS$av?3@u!}zdA*G94wZur;{?I{K>{>@ zV0GY{*GMCqCHH;a@IcA0!tlS}r9JSPLIEAyzvU+0RQ`$#t)-oO9sqmRS4hFnBu@T1 zYVzCze$|>zl{jVjjVsVF?w}EfqvFe=CxJ&NgxasM>3YiVjKBW;FV8NlD*Mn3qvxIA zC;HWoUl$kqQJeq~G!i!s2%(2NK}(J<+x(u}W3#~BrZw8PN&xS#7@7+a28Y_y#+boF ztUH@S5GX5zQ^}>Zg>yMV)PA}Pzj(jLG0Wpg0FTKS54<;onw$uGm-E+@Pi)Tkm z%U^d`ywvqj-P+6EnkAkJ@w5;sLl|t~IiTr<&y(kZYV2g<$o}+m9ld||fszo1JkU#fqu7O z{yhSXv&$0C5B5v(doxtc;^2GpuQh)S4m^t09o6nG>LH3JQRR||@FCi=gw}j|>NU|3)nYG~_3JA^D-a^s3;CZs9xbBIFS3@JRZ1F;JbQ0R#3oakIY|wI3zpRUz_;yA& z&C%OF{QS2ZW{<8fXY(c3xqo|$ARsDJ8=Dp|a!Bs1rmYptvciPQ@SsJ;lP6P(SF(=i z-26Ej742A|rv!QM&vmbsSsrtaCdU|~J`gZiafOq43K*sx0#niI0S6uLc_$rsN)h(> z$9KOy_%)W-wgn@m)e_Q}xg<8u?Xz&o2J0&7aFr`_&-WCy>Vj`Ht6LS9^NuI?NPR#<@lwezm=i9t%zmwxL-%LK`Kw9g8X_0TQQI>;Iz~+duJ3{+cBiEYfjKSCdvZ?CZN%k;j-ORsh*V-kOL{3 zyXdy11N|coDW$%ReO^d|p->3;Gq~+z5>&FG8~A~qmyf#J(>gOCHx=sv+h4)Q2KuGO zX}E^{s;WBvyaVV0$E1!`=q|eF`Z!sOF|_#3Aiisf#vX>zDIws@X^m|jOFs_LaOIdO zLb_WKg)N@`H?{Wbzo@m4ke^c@qmZSY75?WH1-zb@ch%>CiV5D?u{*EDCey)=_}n^$ z_lti~bE*?juoWAjLLg?;Rq(=f7wTUs_3~Uq%-JVaxILG?pf0bY0Q~HBF2)0?mZH{c z{GYgWE%>$WdyT+Jc^-cBG<8?~s06XG1O5= zyAZ|seGrlb_ML}^Q$xsxAj{3A zrKN_(qcSudDCN;J`{lO2EJYE~fhJT9p=98hf#9aDq9g!F2R22?PwSFfJKIvbEMZ;tC3kn@gKgCO#S>TdK=^j+AsG>sagML79MNlro%lE-11vwD41)`l@|GdRF zP4+*-5c)Ds@vEXF{y`3|Nua%T)*t;Cso`j}j(&sj`c8^t4o@Q^;-8$+_PGMo6&U^I zH{?^)yjWxt28Y9|NH>)~f4)NFtXG?eg52P6RP2}Riu~75NQ$g|%MfZ$yghS#T5!ba z12(_4=99(6Es@n2=s+|;Un!LA7QXoQR@Lq=5*(egIWEjFG-9&jeq!d)RVQPQe6VXx z);d7_i6dE5BGEuT{jOW@P-$Gwy}4aJH=4RI_aC5tE2A>YhrV@UHakuxy|Nlowm3I# z)!Rg&&P|C+e`VG*<|QbK}mPULHMJ1;am{IdcX_liiDOev)5m9W0F|REda5})D2;}tEp-2k0Le9rbK+SP>984jk zxV1bhAlkn_6U0{pb#$_mW#Is2!iWr~pi^3pAU#G$M`zWFG>HwqLbgq8^Il_ii&m|4 zE_Z~>5T1tfx*rfrOED4^{ND2TF(<7avdv1r{jG-&?TUWsD1+rINE20H^$QfMO<)kXW0 z^`q6TBd0^pQ2Z;cr+14>Y_;v!oIg_RjKND@uKyP%S+V9-`lzu3mycaJR&P&sqp$sn z+OSF&WVj_KOE62?!ZApT_pW-DT3a(s9s{UjBWyp!rhMt^YYmH*ww&~`nb}rO>{d3x zj^imcqPZ7X22|)8tzboV%hwXCVVNb!{YobwSc_fAc^o$I z{AdI^gGR|M^Hz4Y2I@?daY%Kt+HL8aptt%yI0(N=dQfKw6E{4NOH!TT`UJO1=U6i( zA<{~Wak>Ho^&%HeF8+#6F9R)waAESY8Sryv|0C=|;A@N61ZU(E z6&3XXBYhT^kJzDEkY%kM+<*Q)g*G!ym7kOZO;B}_$kH!cH>($@4gPOa@I;vo=&V9= z6+}@gmcZBJ%-wxsqbpb1@0)h7Vmj)Tv!4O>-VI)~TuWs%S&61F18%b$^mk`aIv);& zm>`+k+V0I-1wx{QggEFA1c`)DUMoc{j_dxk4zO9<-G`-oi_I=vKsAGKKK=Rg%Tryt zMjY(aBH}=X1CoL}EzJIi9fV&iDHz74%2UxJR?igv9M{n+wDnOj4#!r!nQ#9cK;UlV z*cK|C&fpTB;IZ2iY8=^8di*><`SxC;ZP}|970x{{LU{Ur$T^|Nm`IEVdI(NIH#? z+u+8roxodtxLo6jEL@Mj1~3aGyIrHNYR3*hNx7zb$ya(qDxjGZ>krq(#nK$LEm<`U9Yo9&iU*y2#I>P-{u7fwcxh#7!DmH}ANZXTaCG-Ks@*XqIc87n8O z^FHATaP*i7lW|*2Z3|JKHRYe>M^R%t+<$fHl!o^v&4KS35wiRHkHg>U-3+45BpR=J z7dWXD;``Auqnu}eNz_n%m<;UHty?`amB~}4gcFfJka53^>Huc=;2XO^(b#Ch{CBikYI+Kjpy+K&k1SK z{Eb1oJSxe1$0r#CdeKmp;`0x-Y+h{BmSnffUsD+Vj*ca~yjW9^5RDy^XpT!w-=G1>#6wn=>G~>E$<%Nmu{+3=?8kdzT zGjRoL$|sBZEDll7W=m^pydonb#VsEgLFV_cN84mLT>s&bUdDNHM0k9A^C_q>BK74L z90B|9-MhEWfIInJ4Gh#|_!Hg>^Er_RL%aw93^q8L>A|eu13#Oe$9&hw$mk07Va(+Z zYfc;^V+Ky1DG6Cf({COIeA6tl`wW``szkMtx5Gm~8p=xhpCKN|clLCUdg4 zY3keWt-QPxCNYV{XSuF#_?&&#lrM;?hJBOySD18XzP?4M6K8$|Lxuv7dhK^%K+%Zn!>^g)?7WfZdljY}f=tgveC=TjW#o_( zCEg`G2T>vDeYKT74jYzrl*tUag_CmR%?TwQ^GV<53efV)l9B*2Ix5_@yjO132wW8& z?(B%igUrM5Tgfn$43)8K`zkE#JYz~VA>|c&-A5A6+We7s1=8vdrrBsf4FdscqOw7_ zdXy)E@k1~%)HJ==sK@x={|qk9oHJ)L14^#y)n^_;^20A_AQb4|KC5qwcI_G~1ojUq zulVu97aG||T^dLRcgBYf-$4&cLD4>R?YC{V+7tO~nLKRqdO`(5k6f2Rjy_05O)Zle zk%7~7qRG|sx@E!9rb4!^x|&AWfq}acEva;Y?aB*>Qgm-PtY`+}zpKv7gO3NAq?AoO+J*Ad`<26JL zksGTBR=r*Cw4>{Y;Ag~qmUUia`b|J4Gns+Os8RnK_uxn2Y6NT6b-E#F7Lz-mX9+~} zF6CR|YD{G^ddwJ=xVSj_;Re(|T|nEUJto?`eD}_-OFc7WW=8l-NxH2zmsTh zP~cBmlRN5F#ng!urx`ofyw7g^MY>9|m+6S# z4<|%K4~z;hd~=<#9Dx#qt+%ME_?V}Lr0OrxX^U;>wIsLWGSbG}Pte1Z9&A98p;!U0kh+*NpZs9w;eGoy6IRn6J=&LH8qr3b5CI4; zT2OT}lkR(Bs-Mn=4m_zeb8L@mV)4gqoHTDzg~JEh_%vEmDG1tMsw3$&gAk`}z8A>SpHZssop-+)O7-9KU%N72Ca=I=~39slvI13`-0S%W8AMC$5 zcD9y-!SF`_Atam=ASaGd33h2NV^eSw`b?|*TCf@%OS}81Q=|DJ{1IRcq^<$i#*C{O z<;65iWZ?l2mWbFtV;Xir1nHNknr-L3V8{V*(+}0u->$WRksJUS%s^@?qF6$a?dYWb z7Lz8OzvD2o75D(M$bCPzeBRl9NNjNULG+Fs=hw_#3OJEcTeBuge-Pd57P%+rd!$ge zXx0tXTXbibSL**Z0wYF>GaFfH3Q`eh647 zq>)$YBQGBtwYavX`YZUJGa*S}>U%8tQCuNCE@jk@OcwLoN+Gj#QI6)3RzR0gUcTy+ zFSvK3&))$5YcZ)2S+@z0E7aRz>!rm!s7jW{3u)XYb>qd8#(9u%_Jj2uIr~6NY5dts zfnkd>EAsU~2mS*{GR+<5evVKxdB@>r7|mH7Kbbz18S;Mp`-=zk@#l_Zm+f1U0VbcB z(?r5Q(vE$Vp3mu~Vy-xpRb?{wxhpt$`lr-$q3eE)eOJ5#{l2_W2<>zr<-IkI{W)u7 zL!))SD?O%Z$N8nZFV{MnxufkCj!@p?e$j#%mOLT8w?z$>!TV!V%;W`W7x^bShvYB2 z1sh6X4+%(a##gr;TTKjTg3as()J=@3Z6G71<>sEwk5^gD$a_#tQ%B_c?V;Rl#57uJL@ zAzqh69EC#0{NxjIGw%xLIWmS`IUN&qDxyq@Qzg1G%*JinCQ?bk`m84|L*#M;+YR1M zlu|LNz7%=7GyjrG#`b4htdJgVlJh+EE%1>MQA1i9A0K6`CzuSWL+@^Tuw|pl#j8*7 z9K>L1a_YKaT+K$v_5|f+V}r=;F^u=)1$SWv4=`!f_uB*QqU&7rsv019m~Yy(TQ^fi zO9^12o<(V7_rB7=T5R2;3a8CsKme_=*OMjDTkP%;8jU@jRu7yOwC4nyYg z&f?ga)2G#&H{Zy}mtAS=HFyv;CN89-J{9i~(W-1oBoMy=O_a`)*AP=)+<_Sc%OJxj z>q!+qD=Q@ylKmrd)p8lKWh7il;9nuRSN$%Ih>2;~v13Qs`UdF~b0aJ%;x;1R&uT?b z?3Evm1e8LRx?oIp0PzJwI*~CCu-&NE3z$SjK-Foq^tF$5LVw1s|LjCKsarGdXjGJ) zXU}&Tg;RrgM}R3Z3~E@>knl-16!!vJW*I=@K$}y**$!WWna$32gEn%`UXct?9EcgC z9LeGUcW$Bl*f0$X03ETzQV^!y?><3bJdrxB>4B*xO(%DyH}0;yQYp{STN^3oDA z7Zpqw%61Wadw5*M5jotwUhrVQq88kjhAUUDRMb`vYmH(btZExVI+;j-Jr^HsF&GSt zxeik=!lUVkzll9~I*IiPQ-ERL&W9QK>6gh)v8YKl;4<|=G+GT`Ys7ki$*)s*D0~<* z+I$W`blK{|kYgzDWFY!M%`}^e!{!X-0$I{$v#=`Ix{oF2RVMl=t|Y+BpQ8|U@S}LA z0Ps~{o+~Vb5N4MHK&c#c*8V=jf>=_;_xs`GnOdvS{z$1trr(}Mig!hOLS04S_&A}9 zNXuj*^s&-RoE>DZ8VV@n9}V3a#jr07nB`oANhRQOyFu>CGC>Pq0d^8u6xJy8Z6_Zo zj~gpw{3h$6JxN{2)w_g3w9gK1j9p%l*NVjqLNM&{G#J0!glsSO#kKTMC*ZQjm3e$SM+DKmQKN^Ws3MeQTY06A`=cE!Xn94^lyU?AIp zC=5r#yr#MrP z9t63xy>4Y`GaRu2;l$J}by-n!F7X!n3~kyCaSp&mCa}CtP1V3!@c%A=C_E^c?+4pC z_Rr#$`!JM+j1VCLYUa)8HnA2Ek0Zgx=M}q}cpfZD_ddf{5UsXw!mvqV>w`){L@I$T zN@!-~Dax8;^cE-j?YKbX2EmOwJIiD&Tz4rFK+2kyVEW5X{QTX^GND&Br#i4^Uz1k7 z72LfHylVHV`ur=Uwed7gWzJU0d>PAprSyu_PKR*s-26!~Mk+%OuRcePoP2q8|39Ud zlb7Zd?=#n^e5Mt>IRDx$otP^oUj%67(q*4tbJ`eV&~tE@M{<|yjlw?3{heRdVwS7g z4oh04#I$I0HkJ?!wB&igy55Vi>4DB`MuwlJjLYW_=POHP1Nc~q{#?$6uzQS5ZQ|lr<% z1o+MR@~rM6RD|tohQjPR4Z~8*OU`)-uOuV2D4T>~1LV(}?T(2uUcC>yQXsTrLglBCxc&Cgl5~ztUGS(q0WhwaLrk!splIaWA zlhM>uGaiq;8ygkn%UFWas^kHi)2WXjdw|@gz7BYZ(#iUxJN@ez<1tzq8e={ae@wBZ z;2GGyN;p1^9%B~^40xuK-1K#HOo$r*>lVRENpOCW!?;Jn7RNcIcAyVjg+$`k@4Dm6 zHETAUBtyW+?eQ~GE%PKQ4I6F%K3F<+tS>d_Ol+T%Jim8mM(ZA; ze;QB7P5XC#-{`Dm^i7m9IBLmK9GW0`X((KJel@~XTV5tKp{(5~`absSPhsDqGaD#G z>Hi-ue96vOB(JCQg1C^gIi`w`-4v;b4B_D1>(2#!{q5bY0}PK|*m-n3!CWd&VD1Qp zlO>7q>`DpsIzpgqqeMsjsY1vVtu`%mA3mUA=#9EVFaJZDZNYyC^t*V97N8S9CKCMo zANw9mT$_4f>6^_D_VGJVyD_(Nbb607Y&UvEc?(5p`!l?a0zN?7$7;&_Gtt00GvTHr z=#VV8@V#XKPnJ+-4Qn#jJhhfOtN|R+KQ-QP$INzZwQam=_ujMReM-wN>-b#ZCEF$r zL(>3fHyUhV{@TN%Ue258Nm9=(+rZWGVaGCqKpSub7*lm~mZ6+#WJ(a?6I;tLBt;(8 ze_+1sOAswM!G@8WQP76>!c6P7Gg_j+NXLGSU8xYDj5)Px-CB0x=4!;v-eL2KxzK{n*bhA6SaVdrU z0b}psOlaLD9Eymas-u$n&+m4p+R8U-l4DpFB5xCQmH&=m2<>FK1Zr(*2FgA^)#5A} z!W=At&cO_9H1*}~lglQpZ7IM3*H01BvH@Ry2eX5cKjiVkyFzC6L0=NsPIJ{%pAb%} zp#NuFCrH!BPiOED&=5Kne{ZTllZ?woFRV(wyE?Ypj>`i_?iA#AFhCri)~$l%xPh>d zHFgRr74i9?K#xj+nUQQP`cT^|bwkE~UWi*Cw&1Fnag_xTtDdXcoYkxu!bzU`sm3qo zVtSRnv(3OGKF-x?S=TT-dv{gFKiqapIHhg0xWAZscZLk#B)Bx7%ZU26X$rl%*qRkJ zPGK5I#@ncBHgkATe`z1i_)TXmhUQYM7I;-S%-epxV09z7M?Ke@;2Z5Bw0r^rCS2%{ z$@Gd8i2NqhK1ABKVo*ihH}OkhBu)kqryWFKBH|Dz8|j?Gh^l(VxgbX_ z{6W@Y>mruDytR?@e(fvPpcbVzv6%0q(n-Jxgde6wab*aP-vHozco4HE*LFJ9bR}S- ze)}!w&ktlOFV1~Mxd1Kv@kmilWgENK!|VptW7fJcE6_QVmSs+?mtnUiC&J_bHXq;& z&47yc{_~X8A;Rxi*4uS3m$M4{Z6fd-R`tIwZ1v>APv zj_ll|&YrCDE%F>pi=mLgZ0o1weqn>aof}c8$VL*`?@5f1-AS_@x0H}R0Duhcrzq~C z_P_yRpJ#3I@9CHbjhNRYUG#;t8zRBUkNA5gli-z*ZbW!4anvfQ@}~ zg|)QHg5n-P8ExH12lJlCZ-AIz$wCxnvqkull$12)X@25Df60s7MyAEgArr+||I@Yx zltWi(L!enV6KM1t%>M}qIa${F7Z}a)wTnjKPzrXVz>O`1HFO(f|_o{$|J?4 z_zcKGM-fRpTG@u0Sh6eUncDwkwu@a^)dY$xfm5V%r`tIfQCV+jbuVCvh!xX_A3>Qn zZ{Id22d>X`O_IT7iYe|cRem%pS)_F-UIKv>z4Fl*LJs7O(@ z%7(J*H?y;Q-nZQfX+4CjCFbNsS-m;WFxBQ29ByE~bdi^$=63H-Lkt(F>ZqJnNm}1XN84}I z_C}6f{ij9rnYz0Fx_hm51cu$~Y9AN*&*X}N1Jhgly;Y}A(ly_$$JFgjTkalb(oN;` zrcc^e*Lqt$-#=&d40ZJ;g)yg_&p8;h@?g-HqKZp}?bsH%|N9U|ceJRfGGoH7{H_)) zl=zzEC9GVwOdCc7)dQL6%lC6l6JdhxY@9}kMF7}f6PpZKjOIs1-b*4{&Upr^$6TP9 zr{|zZ>zHSGF37k4`KNim0Rx8HY&QNr#B#-|mDvlQYxVC_*d*n1kniZ{=b8L?bLnJ= zOG!$TPCmQBN7pekQ}NtTI(xH|W^RmKD^`-p($A*Dc1tmA`(u zTeQ27oNU&$d3aLNf$y!>9xe8AVdP1pnb4_@~e+2ZiR-+SKu zwYK+h?~AVQdmo>)=451hRZu6YpON6ml-E{{jysvWWW2}iM(CO_D$P==fm5L$nY6|P zRs73>8@C(YdAADFr>=GHqEh?N|40PdOp1!#7uYp1ru(tG!*<=x&zD4_(Xyo(I62?z z(+aY*5%SZf9fxO+&sTDovXW$~TA@F7{P;ON5;X8qe;%*s)Z4vc`Knb(*$aUi=hSU= znHq8CnbPD2`sSURiTx2cZ@}jam1Xq?*xA~aKAN`pW;>_1?He%=umo4sF58Pho zEHo7<4k$vI>a-c=fMNPaa=|a@wO-Gf+B@wbAkEFa2R;t!lesdxU*Eo)jvVPib+G#R zK(8P~(>+c_$Hg@U2H@=o{vf#5-Q{1bCQZs@N10J-Dr0hjeZMk1Ww|TreB7Rc2ip`> zbknXXXxy-&>HPWn1?#F7`i6vdYQ9Ze&X~tO!eMoDbKAOY+XnQ@fbYK=s|(485Ji!k z^@Y~y^o+mEX3k}`x=ch_xiihCt?@yYm`U>X(()sJhDueWj(6qe=4`xl9nrMJ5 z=+oDPl#~waxjq&fn@-J8)mXK!W_`3ZyFx>Adign@&a|_$dsqH(ZNb}mO|X+-O2UtJ z`t*kLM0#;77-MNEc5t_Q&7xT~#{da%=CQ&NeTZtKMhTen1bO7>$6wk=t=YL#C)4gG zpqHzHUe55&48jlxJ9n6tz2Dp0Gm(zKe(%lbop4cbKbUyl_UsUDwH3fw%34|z?4f*3 zwAb6}y&G1do2#|&=cW87PnKp4S5;Q}qm{#il9+YzfIFH*`bt%dzK!anbLY7+7_5@FFG##-%^m)i)a7Kz)sNpLWJ?cU!F4EcxT7Pb2Z`!wPEF$yA>8 zOY$W8)1cf>oD{=hpS1p(dTpQXXMGbm+nd9?+o+x~uzYfV1G<|u(D^>JUL^U`nR{q? zC8|O@&Zzsva?Xq#`{s?`ub4BWyNg{xd3MeGMT^!&E5dQsR931p?a|%~U)&Zgx(;$Z z$fZuQMM?-J6_ZX*Pb2d@S^Dt7gHx@epPyN-y_Tg$`3=)(&ucA@k>+(a$lnNB_J9gY&U)RSZZ{#=L!sF8eye@bA zNum0Z0=+s*WumRqyOhm-W4dmTzTLH# zo}N-^YR^|+K7INW78!YM(S3pyL(jwv{h>FfOr0tis*b+C3mBz)-!_HWL7SxGgiGU< zpr12=+@e<6LS_9DGe4J2UbJ{JU|jO^$Wh{eAdC1$J$(2u*l<@}U9ZMBjJ@la5C+$u zOdV}!CzHn7XXcJJH?K#P?o~g&wdK6?1xkf{v&v6ra(=Dly!M1k5iq1W zW`|j`t6!@}I|ig<@EBCD9Y!?4>-@Ytkq5oCKgpbb{4?uMu=opf$B{#6Bn5q@HgGw! zWX&6paf+M7OWjLTKj}}HkW-jE`J@4mAIf+`V#K}2k8k|knfJW}Pv8wOg<`ay8#$Jk z=3&9TeYv;c4HSjn=cqOEHQT#yU)Hl;@#p7zU%wgLHrdsuZ{;GF>j%+9-XC(X$nN%b zJD^TjW2to)F8qs7BnV)^$KJy)JM5Q_b@%SB_QrRYeAk%C8Hh=$3od2qU zGaR>&?-x%~nbXO^yTWrqd3iYvXoh!fm6OAgz%PHw zp`eo1mR!CRy@WK2Ot-YY>(N*f@1+&@?s8BV5nV=!k4VIKnA!@bz5V%4;`jo3B1f6I z0@ZCJHxiyuZ5;^@pZWOY6BSlH>`?OYQU9j&6DmtY4pjTF<#p*j)m43u#l=a&UToYt zC2(c7)(X#xK3bQzQx3?%BhdmdZ``-9C1In(@U>fceM7x|X_(|GyX06Z82i7_V#&YanuUySUsr_Pl;1E0mAIJ+)geY;72Pw2STqZ!;;yhQ!O zup=WTqGpjZcg?vEuPE&3(FVK6Tszrec7c}HSOm|!16M-dbT`5F(IWv^%rK&p|R`soN`*3 z^dYlfg1JP^+S7T157$j1v(?0jY4F)O^~0G`=an^;Ar9`oYCVBL`QFK-FMEZ*2e5w z7b|IW5`sM>W>5F z=)-Z{pZXUaj-DI+&C0o;cCB~caI0fW?A2V1QfkYxRh5#ntPEcaNnYbfbC2~u`0nuU zC4IXshP9LG$h2$ogtYOVFZOS#YqQgF(j+aO27_F6+A4;5yx*Q(xSG(+a~*-#R1cG5 zCr{ofGPvDy)Q5sDQ=5HDIbZG?j8IciY20FMK+g7aRnEa7$p!~yigGCginwGlET>40 zO+|fMHdU!z<%>eQ0wpmDx3#yAV3_>!&hMqAChQ}nLwxn}<$6wrhNfm(X6BYNVUe&* z&9{wQl(-@%FYQClxs5z~@lX`Phi+>@@7(-+S>W`p$=dR3%EbqI+L&kN=lzT&2IqNb z*@XA#-Fw%b%TEq%==HehHq-}mnO7TH^lk0^ZS@qx^L?N8$(--`v~S5k^V8GbQ5!d* zi4{VEN^Ar%n4))b;)2#_dwQM8LHn?i#i{M_{FgO6Tpa3NDH+HEcgipCNG+V-%}3p^ zOvhx5N+*M@$;pG6U}`T1ivTg`_%~4GDm&G9V+2*$j%|R=cP4fRS4ooK?v=Z8D=%F3 zS?w}+u1bp*dkXUx1Yam+c7M->D>E!6X6fJhcACPnaPQ+!J3V(i?^B|iax*1FYcA=r zFeh{iBZuO;e`nTQQu#E;QO)?cn-3p;9_IL=&yFt__iF=yO5uj#V1wx`$24P%fd*A} zC%Da@FSB$_etI-*Ibp&CLD^tP_>b=#ro6Br!6F=~3tw^xR2*pWI%NGW+A^X_skLawD70C(e&Ged(P!JH=(Pt80%oZQDLc zam!Bho71$Z!;&QT8?jU2!ut zO{2!}n^7Kg$z#iTi;?kvd%vrey96?MA6F8&CyPq%-ulaDq)eU?0 z-0!>p&=xU)EQfi-N)l z=tZorn5w(!=cTM$H%=Tsz9}|VuW;=(Wfd0}7jDXx`tMeU81sw60=Mbw8Gud|dl5?}snfT&g51t1B0EW*T8% zapcma;OZ&k#x;QgJa+8ZOXY2?vlA@?8}LV}C@bA@DSPHIV&gA(T9-B zdj<~~GP`>;;x&S@66ql5rHOZe7X6Ekp5E)YULB7S^AjH|tXS8d^q{Vy%v~)$x7gQK zFA!n9Jq?Ova-JEG?^#pN?82F_#oVjWtQj|Jy^vEND+IaY8_=N<7}K6SX}?wPCF)gWZ#nA~jEidT9`pFk zMIr%V#cKNWXYqzx!MjAsN-+3Z;$3^uskJp-G7oP-wCde+R^hOYIf~Fr-xsb~Q|0mT zJ0=476OPc(((-Vu+jnP;vTjiBJgBM;cMo8p|LV<~t5p+tRnqHx{aS=J^_kP8Nk^AN z_v+aF@vVx(zMdzB{hg%NU{2$;g>NZqO^4*@Zq#j`mshwZa)x0&&iU>G2X4#>NFCFy zoZkb^RKLZiFQ12mY}wM_;GsjiaTXXjIiqL&-ILOn>ts@lRfZfY84$gFws)qUx{9pj zy7a5~xUjm4u$+Pwm%SYE)0pVs@R!Z|JW3x5SNxsKIEplBC=+uMEk4u?23O0EH=j^f znqFvBIrMb?zy_~+Dru@HUn+b^5i^2s&WaO+d2Y0PmwaP?W{UwNSaAtf= zj5s!j$Hrz|u<~r}$$!!d6lMOqXC3g9eEr%v=yTA5A#=L*5HCSPeWT^emW6SlaTmQ~ zzLrl6?Rvbl6-v}PGw%LUd-U+(D@Z}Ji4&XAxMJ#ir(hs2qUYpO3T8rNGTO6leN9bG zqh%k|V~-ziFng#lpn}@MWoSjb4UXW6N>PHjrp=;&=r4%Fp$;1u8yk= zJ|}9zQ6LDOFD^-&Xg2P6QO<*NU85Y;%(8!edUhj7MU9$@xOboDI;~_vL2>1UcMLez zkgXZc;M7aI4;can0`97Ksq8IuXnFaHb~tCvF*{utIb|hKcuN5ra93uadbdZgzrVlB znZ0X-51l$~+MxJNQyglcvJ<-Lv}R53r-f?^J!5SBDA2#nIp0`w&OfzGEC)ZuKkiw4 zT;J%5_Af4^8r$01o;0<&uI1Hjj>}``BJ0cEI7wVfQG25_qb!=?x>)RG^)Xq6E z!&1*EYNxIJ-q~_G+f-)K5n;gmx@^bZychFz`b^Qg?Dc+Bb51s@FSc*(7*^h->)c@) z5bnOv{!4aF;kkp4238G$osx({Ujk1|Dnq$&?FmYb`bbjEXU&StzEJmrp4a106MOIL zECNvFHIZ4C?2KV@O*r!Ga$rn~Ra3|uKH#h`@l^11fa_1rZ<$h1ko~>vVQ4PC;HXW} ziO)=JTi2mb?4)V(zNL&tU?&0?G@gP*LAgd}o8@8@Wy(q9on6W-h1_dW+3q_+aL2xV zFB-H2d;~V$WI4YvX%6t4vg_5e3m%51X4^)lUo_epzqXh}bocS&V@t;VM1YWnE&@^5 zE8EQ07Hdns%j<+9pg!w<<$BMal?Z+Yy>LY~y`JJIsQEME<}QmU!^|pJx)FMTaH(57 zs*q74SUXnwb3@559M9UrbZ+9sXIBhL`5ceb0*V!mbBR&}jguIAfoK`MI4@KIiHv-x z<~ARbj!u`4G(SSMr3`WYZoq(0bUr1@F(fktY_eHhDsH?obqowfGxs81$*6vs+?*WL z2nsUiNK#TO#$rb;UD}74qgTbn`SJe8pt7!Di%!sL)!$8nRzwJOZYLOzGxNj2pP6(L zNzj6VOF{SdcN;ixe-pobfq`{CO!d7TSGwH8<73}x%y~7IxgE^Ag|wwNFzLKo8`_8S z%Z>w=)T-}g!dH6sxWl1Iz@^R9vndZ=Y~#5gzzLhcbqU2tcpCpp1}y2`3@AX{Nu`_w zE&yb_LJud(ct+Gb>}ys3MWLD2lR4CO?b}bBHqBp}@#*_Iph1e}*$p5kLI(4iBBItQ61&cYZElii?DrM#leIu%;vpPZX{rZ|4u~ zh8-vWg8_7|l^q#qRar}~p zNy zUyI^w!TWnlIMnUMu2A4D#wWzilmlI-EPi02$)Lf5UlSl@=At+BHnrRpbhq8_&+GaX zVHwlCq%SqRpbB%_UNDw8jLKJ7L@2?z3l>~MFTke_;Nca#=+nD5>$>XFwRYM)rnh&^ z+91nm8%zXAnhE+=nA& zPUeYH& zh69HVjX(iMdY6ivlr}r7_cwv6WDCtpbPXaop-`7a31m{E!GpgI-%MFUy!#50U1!(r z$I8h{xd>3qggIxw4CBNN5xJ$4U#Qi#r>C1S_r1iV`et84L!mc#dL2;UNmAkZiT56L zipS3Vz+|$v0nv&w>PcFSwEOp4I486cuUDdzVCe*;$NBkYbd-EvPKhr<3Sn>WcuXUu z&{$vxATEu#_{fTD(&z6t9c!)T{-!p>{UKc#cBBo>ftZjXZle2NoR}0FTfF|M=I@VA zwPtJ93;~x1A6b9>`Y52a2z)OA3Js%-juH+9mc#+HqgR=1#I^!wAvkdUZUh@Z5c>n2 zOJt#S67jior#vzYYHr%<&rmw;XJ0l5dW0~5K{RQq2|1z@VVftxY?Oa?Iq+|P;HDW2 z9g)st(~RcBWADEo&9dFx^?oo^yyE>JRtP_QYg!6@=Zx^U2t?{8-#GQ9{eWi53L?s{%hoDXqAoh`p}f9Pdg4Ddt0{Kh~Y>c7IJ@j*JYK zh`EKOr4MbotbL=HmJM$?3q=Pnas=5TB<07MGcC|XbQx?y0$CuoH-I=)epiRJYYp?b z(L~#Xu_KnksPQ(_SUJQGi$?ujjWlK#bxa(~A7N6zq_T9T2*p8qMuF^5srPn!pd;H) zy0>uaop}M)82Cu+${>z#Y3m>=3QljAzy2EczVbb%3h)^jPMs%Do=6&orUME{V6tz( z2y0FRQZ7ye(!dcABzPe7M$H2!?E)X@mC_OdNU{#jKK(fScb`4W9{z=lx2mdvA1M{X zTja1LF!!8`Jf!P@I7{#pEKeZSF#))l%&L~u^EB&qvTohlO*O)Gq8+U($^%$C^46fV zU_?AzX+Ew9!Kfq`ELRXe{H&&bhPzf%JvVxH*}jdp56N4bc9gky>7$@1-`*a~ESNq$ zq?|NJvxa0`F{l@MIq07+*~?!NR@Jq%w!njnZycDDVcKYgPOo13bbmZ)%RJi9)nzK^ zSm!QY>{?p@?mD|_?21C{tq5e=&|bW92~6kQG?DYJx^-{blqs)dCnCFfA#td@kX~(} zSk}Mvcnc`XOZ71^6BV7n`p>t(j+shc%s#FT*i0;6%Q%b}3)H`)x-t^!N5{fcN@H7) zk_5zB+_IPc1}Rn_R?<{`&GD{j!PT=d>hfXsa)7sB2FJj$KKI?c%5S z-e~XGH6M0I-9KYlO=Vlv#*MM^Q=pjNvPQ>%0kcIfjcb3Z5r~Jq^;f?6vSH)KEJ44z zt6e6*0T96Lr}1`Oo6lr%jo0w&x0VHC2x1^ZGWSMzY*JxWf0oXR7JNqQiS;|lcWQnd zob%@9zb!3CjvwFj=z@<+k{HGyRBH5!2TgtZ^W~70BJK(d46M%jGSSB7DsvLM>|t#mw1jEu&pv>r^hB8nEy>PeL)st2-9m%+yam&^kJd^Y;RK+Zj! zx!hK`&jvGRZHk9D#nw(1F!C*P-hw!wN6uTKwn^QrCh1zDi>t!t!zyEomCi``@@eUJa&i#m_os^LN1aO+ zU))wZE;e>!e0<5Nz>(d2bh>lkX~3lhC*Gn40P7IW0mbl1r*3;27Wo4p(Vr;RpWpu8 z-;B?IShn4lnM3=6sguo48Q(W+Z>G|Ey1|#$+tulnyxdo=l*M$QFx2>}XbHeYoe^8} zZm$j-%1;<*Gs`EtO%P1;pV=3i33idd&#ZH`-8!=KoomluVDH&F(4*_yZjv8mp$wOD z!He@0ap}y2qSb2k<<;)gAlk$x73!mu`7DDQv@pA&9?=1YO8$&va)f7#mf-?#dI{>d9EW}@) Date: Thu, 14 May 2026 18:49:20 +0000 Subject: [PATCH 08/10] Add missing README images: llama3_8gpu_tflops and lingua-1b convergence Recover llama3_8gpu_tflops.png from PR #1500 branch and copy lingua-1b-loss-curve.png and lingua-1b-step-time.png from images/recipes/ to images/llama3/ to match README paths. Signed-off-by: Savitha Srinivasan --- .../images/llama3/lingua-1b-loss-curve.png | Bin 0 -> 198651 bytes .../images/llama3/lingua-1b-step-time.png | Bin 0 -> 193488 bytes .../images/llama3/llama3_8gpu_tflops.png | Bin 0 -> 103401 bytes 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/docs/assets/images/llama3/lingua-1b-loss-curve.png create mode 100644 docs/docs/assets/images/llama3/lingua-1b-step-time.png create mode 100644 docs/docs/assets/images/llama3/llama3_8gpu_tflops.png diff --git a/docs/docs/assets/images/llama3/lingua-1b-loss-curve.png b/docs/docs/assets/images/llama3/lingua-1b-loss-curve.png new file mode 100644 index 0000000000000000000000000000000000000000..a782e166a14611f29da02f29a8a9c6822df38336 GIT binary patch literal 198651 zcmb4LbwHHM*IrgY6cGdk0Vye!Qc`*i1VliQu0`pV?h=~@0cj*eq`MbT>CUA=>5gTW zhVNY(@BQw*zkhUD-g)QDInQ}c%na`zONrv0B|8fOf$+q|?#qBcXNEu^teZI4z&A8H zx-%dU0Z9D*9a&q9nSLx+a=GoYMW}pHXP`;mT&ft>wYj;AHa7f2x6hxS!$19&Vov#* zd@{srMaE{uI(~X?smZiQOE*)TyE~Px+w_@#$jX4gnx|IS7E)62lnVg_gn@bdAL@3j z%T!<-YFE(le_p{rwDFyugPOsk8SZB--NaO@k9+qh-T!!g%oRw3D=gZ|NfR}i!%hS@B(k1 zSSy?b?~?+ZSk^J!Q!is(rU5kGT#lY}0VCS(P%uDoW!Lr>eq^kJ0f*vP3ZZp0cuV)^ zpV~dC%fC*F=?cQI#dtw50X}B(&n#sV$8oVc^PiJOJaq*XH`?`)a)BchT>TCebs)h% zXc*#(L*Pn6efp)zA(Vho5WTZEZyc`t06`}W+iZsg*tEDR5||G%PNebI_5L0FFfi{z zT`yq!qL*rUbK>l=rQ%Q%D1$w*4`w@55PHC&xg%82OLf{L6+~Of4)gN$6K_D9;h2*g zz00kc1SkChp(}~T=~r{dmx>hi1r^JKf)hcwX-&|LKG^_U5pR2BxHVH&_s=ll{NW`UwCLhhh4WdlK-0 zJu&~otQ5=;kosvVQ0V*>uqWYv7yw9i;0I-Lt{?%03%SR<8OIsp!J!Th@o_|bd-$LR z0X&B5>JA|xz%j;n@b-Tqm3$&VzY%BkkE@8F7`%Mt--IWi2C#rt?G{ZAxJW42^X$J3 z*jvcnY6Soi92b9Mc>;mdtC$nF4(xom?32*OL6bdrLZkT-FqAqQ^D^GQ4A^&}T@Vid zdP01;M|pAxWC5r7=?Ov5xIaNFn4i$hU=5i#j~)uq!?}6(p9dVXRdXNCnKDDEcaDFK z07G3V4;kb@{Rd#!6#;@?>q2xh{oc@*Dy{?nkwqW7HJ>kQyO4P($Km_+aBuScDaowy(l6DYT4y8Fz&%s+G z(N05mj}JgTfg^bN^5KJf(P1xi{wU{)^1`$BsJ;l>8R{|{h5D+)6j z!gP8FYY1nO@Fckazst@9b?iISk@AoG`GN@GIM9#R-1mR&#sq2lSF`Mb1OvdKIuqsP zZ+A~97#RGI7(~Au3WPX>+z0am*g=NQG}HbUza=O`yUjxRnbg!4mH#pT(9a<c6-hxW@tBoTl{_}Dd6aWRgR9rh&ML?=>n@|0_ z^mnKM1ruDMD|`phz-{LL_W`@=+*@umJ)o$Z`(94RZ5cg`sYiMw_x`6p(iZ|cw2Euo zV!{BQ`vhM8pF*oL8WR$uUAwMxs8~H04CQ~*+8sLdVkwvlKcEpE1SI8bi7|Q<)PKYl+Ye)Y2?_PT7M8Ld7e0RYM)^fgEfBLDO1BUWHIy=ED zw2pt>iLKl70Lx@h%6|;53nJ9{rxqMZX|#`9c30KZ`B)R^8sLCD_dXq^Fw|mz+u^pJ z63yEeI9d}^@6x~FVfSrANk9$)dx)jZYyav{1qbs0i1S7v{SSHp`<3xA1st+EgaGYG zKTrFl9G=L6QnX`9C1c zwG=Sa8q)(iOg?p2%0D!FlszB5c6ZVH*&7YoeQ~bbe{-NGV40BKvyTs}iNo)YOtkkR zw;-Sbg%fh2?tt99iK4atm*wJI0Q7Ws1j2sq_$WPc41+(MJwt-rf^0y9C5?QgpPnPNoI>FT}ZvA$)4#<%?IA zf92B6i_y3L!2tf=meYfk8at-~HJ`R(_Lx7+qyMs>qs7`yw{EI#huy{o4 z&C2Ll5JZPOTwn?t#OuGjhWZ_vF@C$lVQL9n|E$l$cOPOz8Qp^EY>#6-9O%|b6wPq( zQ}iV)GBPqc4i1U2`UpmHZj+Cpp`r50g6T=2TktD(DLUBi4Zu+;rJ#(ZiIu|6_+LbrBMZ_ z3ihRMVDHp^v{W2{nN7LS0TcH3~rsofKhWq z5|2eaH#L%I?ER(c(YtxpKPnhoC)Y-X<^+-9o;^Df>99jP=)6)wFWAm)Zeig?x>G;= zP-2?()amm#88Z&jYYJLO(s4TxfB{RsF1w?JgOi;C$~h#_{bzvEnXP0g)YidhYFJv3 zsMsCxdE8>W!Bobe_0<{4*85`rUodPYMkD0)4XdNn2ta=wRu|DOOx@%!D&`pEz_hpkh^yHl3V^_T6)kJEo7QE3O}N4B=(v9T*L~n8 zPo6ZKr|x;n-XKpHjRl}UcQT|wC!mJtOj%a+JD|Wh_V}(V>=q?%xrBhWQ?wFE_UGN> zq6Mg{Ao^1p7+7&|2}&@!#=Og~%C_jcM3DgixTIyyRA7NX=3avo#Ch+j>8 zJLT@!5A#f}kI{7=fvE@TN__OwT`>U01bm;)VFB~$sbLr#Z#ua_h^lL9&SbM8b~?N=9-fij$oIuo;4N`L7{4`wc|n^iWPE&LLADg*-6O3Rz?J z4G0^yP#Pp@!y1yD=)2}gc^r~_@~-=!!g+#@$;)Nz%m<7R0&+O4%G_=|hCKw7BxmE${j92+coNI3@0R(Q@_bZZaq)Rxls&6yZxAbKNQjJ_T*nya zY`=OK{yS^}Ac_@w9x7-r(m)J>C=OTlPl4id9cO|)BTeAiUipj-PEjsVMl-hX=4wE_ z?asZ#q$CR!SmBok13L?}K6Uld!WKnQ4Dt#J4t>(6VzWK%T5@$97e0qO4PcoMeHN{5 z=Vf-kl%}Z`Tg$NSnvGEz{U+sPa0Jtk>1eN9k2Fa zE4`{4?v#4vMfET9*U)6>gV zZ#-^F!onq~)sz;ZPl&L~opBhVdE>Eu#o>ERC8{PNf(Bo3NqEfy1(nehF5S8elNJeP zfwJNYZGKg?8E%L%o^@GxW?)HQP70ubsm;bxr$=H(T^q<0`Y`3UHPbpM3yNl!JO&3UuM zY=NzMi%9Fh4<8=-Y+Sl@NeQ*FKA$jZ3-a9XBRJeTT#h zmS8%w$^~2A{G$`YQXh-v)`UZlYb5@Wel%q5i_yK@&Gm*ACoCF=P-1KcdUP3W^aml52RM)6@E6S#ZJ{U?*)v);!=S4?uwxorXwzl?z zfgd{#kEHmJmTU7re8fK2AvIX%^-D^65$*V`Xv#|hsT}d(?x&RNKyGemAfqsS<}uN6 zO4$uLmiewi#O*C-@PdkvDE|Ap))sLEu$B2i>F3X%s|ILS)XJIH8SmWG$=+BThu(Le z_Ppy{ytyzeZN50FP~BrSz1fvC>!G8=)fk;nZcUOf%arCp24dU!7QdZ zjg1jJY2s=W6t2O+!E6d@4kdjkAD3Fd)5DgVR9-f++s?aC`XX2hf8ul@zn~>V!ZnLLvnn9_~1A)r-#-rgD%@@urt)=R#cgGfevJ3s#qyM0D|&Z3fv&w<|yS zr59Lqsg%&i$5%f3I3Diu`63C41JjM!{z!PE?`{)OYGZf2@q0TvRIa-=?#cPt*VVq`n1F}&M6-%AnRh>I>V%*$P zTE1Yz)LhyIwaUN07E;K2HvC|1{+`dODy-lo(QKqzku_2=831J01@@yjvik{q$5jZ( zwHVQL0nrt$H^+)DR~roTy@>sS%{L%`b_cfNocQgVIkb%4YD8@Mob6L-ei2I)SMc+o zhIj99Cc9(bdXn+V$UJ?@Vi(CR6uI_ZOLyG1)&#yxFA-ehFz?cQ!x)Ab@Kj6HZR{(M zoa!=c{mGy|>la<_zh12b1rr%_hnZ2ndd;7QMm5BuGpF zWKiKwW2EK?FL>RhV6__?^Lc4;*3)CT#gPd->NY980;10yDnbQt4|)N4<9fk1dqFRh zBpUO$r)9r}AO`hvM3UOk7PTd$KFA~Y6;wu$j|>VMrONC7t;%_=y7~p8HC}>5mlq-A zjSMp>9%xW+U@Teh;cTg@E(~g~3W+pZ*Eb&*-e$Gwa}HCrt`T9kYrRi;%kfjZo~%}! zs@+sKk(N=#PGsWl`u z4Z*|1`^|SMvO$L^BBP<#**MS!?Z(EFlXFRfA@)n+jMY1o&ds;y>?s{rI>syQ%sNRE zJn}S(U%mtv45t=}R57b>>k$l1br=c+Gj->FY2au;0J!t%7AFAs3Yi+689lv@! z%9(k-bR+1B;to%o;z9`UtLHlP|b7}C3ZAyeJvruNjZx|(G6)<_)P(?pa~ zM+ioVpp1h1w8>Vlr1?flUN!6~$_!hzvZ8`?yuu3(W!kZ6Dh&;%8XGVq7vB24l~Ix; zpTa%=TnRpG(Ws!O zG#@!Dcv-j|B9u-#l=HK=vIAZ5a&3}&*xE|z+()1|=r#=bfiHNU(qV+=3_kvf;#_<% zbde7?GLICCQF5U{r}gKVrOs^li0jr0q6=U2jn6Faha#ty1i`#ZzQ;IdH}wOsLk*$g z^8-Ri#DoTi?1>7d>_!j-U&p6cPC_u0CxkqwUZFF9E zOU00wm`I7Mt)pXKZ`gDW@0#OSP1Q)*DU@lhQUr!6P*PCVX^)TC)*G^FHuinNuyv3O zZSONli7W-O6*D-A=)EE@DmHr~ir#?%uFJE&?)XcJT3c_tma3w%1jk57f^F-23ys67 z1mNew8pc9Ltcyx|Eqc!o9su2q0ZeyfvFvzYHO7e&{k!=94EOPg4Rf&kPz>Sj+&PR_naDr}N`UsP13 zwMSDE1+@_1Zc9>-j?y{Hsc_T2U93e3QLjIS6xMYxG2L7>8;c1-C+P(X{&hc9KvCMK zski};penq7;Q$ZOZMdMlN&ERVMKQZG(6$eDxZfK>qs-b6G8c6D`e&kyDr zPoi5ebpcHROfIFvQoQ3YJw~GM`S{`sM^0wULhLY8s?OV2i-}92d}(DA6-RD0sZ@p+ zAi}hycN0n#Mu0I}wd(}AIqZJlKZA?dVXC>lF*BOcn4HXvuBUo80fHbEGWygjAj!;p z!(YW@{LF?B;IHZ5wU(3fy zfA2aEc0Y}`T#`;jZ-vT0r)4fhv(~&zpN2=3~1oIgjK**p(6l>Rq@aKtJi((Fqj*W?{si_6PZWG!si&-SEBL`?3Hk?C3 zB6aU=uJqe*li&X~-n-LpQ+mZ|=8SpMPCp|q2@5L8W$ZhUA{7`PYh@6mpG|+g@58e+ zK+fxMt?i(Mwk4)MeF(JZHXOYdA=q#=ppY@IUu3Ltp>n?Oys?GfqExN?@-b9l(7*q# zNtC}op;CcCjhzHQVQDJ*0Q&jVbOPix+%TdA_|fv*p)7l#cO$q2Y}Z+8 zDE)oknu2x#@cYqC57xlaL9+xvdD2tUXCe(#Xm6R8e`km!LiIiU{=H`>sr38zyJ5QR6t#lFD2HzyuE`y-+NpA`Rx_v;VLhrepI8X-dvxZ28*;L;_Pzbp{3Ao|BPDqWb$+5(bc{CE{d1g(`pkQUM^G*U^rNGrusH@mnN1Y2 z)Li1Xa5Jemxug`1rkY@+d9*_&rfO11od~P(inVz+&=AUm^=t*yE&mE^M zs>6e;(Wb$cUv93>DxkxKCutpOz%Uc;RC967aITNxTmYlZBar&@=Wqt6TZ|)PW1&1w z>P4>@f=@7Id&*1Sq&|Ae^;yXECA5tn%m2>H^&;`Y7jR zn83<%fH^8exB62{ES-RW>gsS9wS6BRFiE*dGskI+@TIXr%^LA#8x7@ zpAos556v_4W*F>fMLm6s=8NwMZLOXvSXz>Z*Y%;4tj(8Log9abAbAEXgGaeifLy+Z zlpP2yZYbht!4n##{)ujyFP^Gja3qSpcqsV37vf~6LQK>nfnaS&$s{Q$83>0uKS?bb zyD>RXQdDK05^~{#RYL?%X-ztVRCqI8N(rIAe?T#JU|>$aAuOSOwgGdFo}S*Wr^C?R z+&oj&vV3i6@$`HRExsT%8c@pW98kx#4u(5Ep>uUWBDmJdSbL1h03Cj2D6<7pY(Q?H z;wY)mKi0IyxNvSS`NpwxV}orzoSK^I^Ue}g4lQ2oYh!;g1j<{ZTW7s#Ggr*;n$boB zq`uNFJK58L(A3L|SJ9m|Y_-EbqmpZc^remQxk<@x>s#drWzY8BTIp9%x8GK>TYq0g z$I@W&_^Xhz-C}DzOm;fskwe~i2Cdevh(N?Zo`AtoQFK%a9hP=4?@+U|f7j$OG&B?( zvy!c(a%n_$EiYD4^gi!bw zQ6l=aWfc_a6re7{taW1K9QGgK(ZaF5RBQnZQoe9?`#acpS0>ETE-}g5*CTjh0h{N! zeEg%ciNk=g4z#u@(Bvxmwe5J;9iil!vAdpZF=`f;hGFMRfq^nI%2T&WY1uP#9h}=I zM#L#ds5>tZ8+VKWZG-squhlYJ3(tPg-D*LcvkdxfoSm_v9w6AM`t&I+mGfpsm7@(Q z)Df9;@gi|y9eXRP+rf0WgbNB3e4~99jZ2E7nWIbsV>xtyq=w>D-TUlbiq(a2#I0^p zR{=>Gd7IAxz5PC1-;-^5ObTd@;|Da4jV`>B@P* z5eL$fd_Ts|CHK3&Ddl+eb zkiemT7tL|@rVjf53b3GqZ6t;?)7GX(+A@nq&_Id3-Ar9*&a}W7FZ)A@x{-OCzUxsg zq>97l%WMP#fY|Uj@&JZ8Wsth6YI2uRE(4N}@S)Qu{ad`(#iaz-=&D#H2dub@RJOPA z6Tg1d^VgMX?q1qP&O2zgeQ6les$2i{mg4HE)602r1tyNjhCK7yc8<|R-AJMcmR7F8 zHq;o4ihy>_rzf$132YoF`R^VCTFW?)@e!^quezm6amv5hRs5fQjF;lk4l8ppSx$CS z2%s4gm@^%UZ-mHpF-*WG6{yw>WGh_{g^PYjRnL-V&0m#S11L$Lv!$EkTT+sC%Wrz( zH>_%f)h5GLO0Apy@7}%h9T}nXDQXBSEz^wT9p1EPBA)f(Z75DHcUOSg+zSw$;H%Yp zG0>2=tp?kga}bpYuqmNrEm@6Wm>hcrl&nibKYSjlsYN_;*`VM>%pgYKgU&o8W<{x^ zP`g>HJiE}__SqYtVw8#WM(gI?u{f{%z+Y-~;281g&Jxl6ifD(RIn`-&G&0j``4W=?DMwpB(EVDP znl`=cX4W>~D6wtbJ8S;^`csi6&DiL-!H6b7(e;(-aTkvPxC8nnhM>@(8aj zbLi{oEfks>>ds}YsH^MF)B%k;APJhS&(hv;KOd+mbz9u>0jJYQGWJN0waeWUA)9i2 z2g%6x8$V`x6Y{F;(LS`u7!?$@Z5?LRGqcu`u3S#~R5~EmN8ql4AS=*#W#35SUHUG; z=7C))xUq@5)I0AYW@eUA<;1DE6!+e|LOXckTh(YciTcODB>94dD1K8|-%6Rgd3y$9 zHI3*MV^y~GPd%s3^Qy9zPAMzp>t*zAaW3X-`uG7U(!@7y{dp9#SJtm5t|yK%>;?cJ zWAUb>0ghNZVd_yGk=q_a;M{zGsVt(Aw>YfTB#2rlO8p*lwY%8PV{l`(_I#|U*T=+! z$V?`=%NtYYbe%SqZ*|(~UwlF0%A^1LXC=5O-FoJh=tlNRA(A|lWyPm-@P@+-Oa;g^ zuiw(o1F(kh1zyh>!cM;%IP zU(AVvbL*CO0sp_P^JZ`J1;R+8v3g|Ia^@GKqrqa9mf3^%6tlJ74{pp%omWt^D_EZE zND}FI_4;*9w{9d&TSCc=6s{~44?aRyetpPjXpSL^V4EoHJQa7le`hORL`k~o;Fvol ze1pI>VSMVT&yKXW@LX4rbyen|^M;h4g1W5~AyJWXs~XSqLQB8Nv{D01H=rc}pCLA` zs;H=O*|A9LolXP7NDHVdk?8DAtrHmzI&sv)ck{+8GxWL8dO&kWX$JZ@Q{J~SfQRam z*o4@&){)4hfeMMBbvYQZ_ZQvRvfhR`2Bjk z2}+GbB7>et^;zI&bSpD9*Xv}}SSdi!31kqrDo>}Wo-?*UKVz$UQ^d|{(S(0!=r?g^ zwlinWnD=ujypNQ=b?cVU>>$uBEEpXlAMFTM_dSK*RazlLcFz90SyiorX$j3hV+G2j zi4pVc;Wiww6Br53-Yb?Mj6m&?7C>qHW^WG>qI*fB@4!lsYdCt4Uh*^;^IHn@bm<@O zKYZ|CSa=atn`=!Y3Qa5g&EDc{n^dZjcd=mNB|=2Eq^)X}uNE~w*pyLQq*m5Htjj1J zsnw)kvV+gJ0J574VWiJd;qeJvXFnBbryGGaA4@u4jqaSl>eCZ8P+}4sjqiPgZKWFp zH<;6WB084lh@Gx1-NWA$9S9MUl;R)NLU(?Po}09x25q2r8x8Q$C_pWi^8@|Hr0}1> zGIN}AFF^j_kG+g57^u-_>HD)OWDZ#8(3yZ^Jv~ORGBPrjq8xTuID|hvt6#EeJyS)V zptq)Cd~nk*lnh6PqTS;SM^yj+{#GWL~gKA20Nh!>iSKb*Y=Y}pnJu6gSP#88_66}!9Vc% zWi3@juQ-2FwO?C6r3EFH2WiLfOz*4-_i|L;8EL-fT=th_bD>g)7Z)>KQBakT#c=hTB!j(z-6Y zZEQ;~_ifE@t&rX^9}jo>GU&sKN?^P9<{WnqY;oOi6y`F>!k6kWzZQ({xtq6Ncz1q1wNsJizGuJLp}0mV{dmltJPV0gO^c z{tu2{0NYj2ECb{o@B!dYw5LXa&K-t<31+b8-nkV3%Q#qWrDKi8?B|?5ng-k_+?#45 z4t$RM5a_#XW(wq*Ju2ZS0)q0xviPvn0J?7-qKd#KJz|V4y`7wxn1P64Mqga>h~N*# z^hmnCLC--0@(DQj4*1#4mqf$$z^-S(N%CoaKqsZ{@dK|b*Uv02261xl>mWMo>`jy7 zTswh-PJ-`7a9^0H2YddpMFPyAOn1{}pqOC;Yi3wCZ22SR@v(c6k{1ITII{T$aCGcb zZ2L0o^(Vf8J#lxn;P&{!DtdqR9t|iBVA1`Dn#Hj$a~~^TDd^E9PE8xof3$pnQTdm{ z10I;bCXx1Jt9Fo+=JV%TeOcW9!koUUciRKFCT}ju)-KSoC zD2u6(h%5Y3$sR6f-79&OmX<}|=;&1^MQA>rS?=YN2le*ukZeT_pufEIJxf|8NX4`u zu>zSv3pQqeR2)~xo>qYK#LYm76#3d6@1!yTEoocps${VUfWo#=ISRRPx8mwC8TSAGLh`bUI4xPA49_&|jT zQtzcWV{ovKdAQ>LS33t6=6wJV+QrcvMm-#EjP?I?O%5y$I5Br^@sL8rl@@oxiJ8%j;6S<~SbRSRp%qvD#mJ2{Q&wiMFy6MBs8}$l!l7(7|QxDD>60 z_wVl>w;ynhad|IfAHWa2=0Btd9zq{=^!p42ux`Z+8Y%V>qo2Vku>M8coT0`Z;2=hDi z1`8aZJ(>$puOXhu)GlZU%F9gMVxltXhZp4<;~ho(q^;V6;SZfILATdlb%^73D0K2R zF$q>>Le;3!3AfcxmviR2iX;GkUZVNAoByCW^9sCmZS#9v#it0ZDzC7-T11+(4Xf}r z>a8?wT2fMge=B=bt;|=8vwANQ*81$H-F^5OF=OX0yGeCU1sP9PI&J=Fg84$XHznj# zl^B!Y#=dqQLDdORCqvBNR_ugtJRoAYeJlNK=A&B!kk{y>QKUE zsa2d5NL#)V#yua~_^{>MaB6CjUxITO1CO(Ln!rYt<=XR6A9j~snYCyAQA>^ArAJMp zwne!~zT_E%$?owiB;kTav@>x0bDV_a^&U_sZo932v#*KFy_378TVPaVS}+33Qi_9p z9_G7Ww}(p0bWj5=P^b)UZ6*n<*Ly5=S+Lb)+X(xW(Xm*a$%7hw z9v%`81+29vfQ^1c^=_zAjLK6o`LyO@KG4fk$7?!B$3g5ux$ds>3dhpbwTOW{?v6NA(=l|5&6o&9D(T1MkRD}}l_2>i@giFg0dbh9{SkM<{R1S0Oab_QpAcd7E% z#Nf;dEE!lhv^p18$P4gk4@@!x^0?e+WdrjhPX`5$8_yDmaVN$_B&7l;>it+!Ujdg2 zqSdf+$46?ZEXE|X=PWrQ5(^{Tnz;uFSkJ z{iYzsCz4Cy=9lpoyiR$VqtzESBz8sy*b{e5zEd4y6$|VDNm(cDk+dOPQmM6_pN^v`pG{ zY7%r?pD+1u)!Ng#Dt;=Q`5Eiudk;0J&=_K$*=OA^=TAU2&in)TGZ&UZZlk?$Wo$JSC}UC< zT&1-O$JwnwCAeC%wD_ve_GV^A9Emah_MO*dIA7R4>65<39uTzOq%dC`Ni=W62(I=> z=x2k^+qlI)dSP;SO&kXsVC$E^uC<8nsc#0n7*N;^P1&iEew4vr`IFdqE)#N3vP3rd zy5EdGIomya6n=Z(afn`57Fu%d# zW^Gl$^`-=SRY$1gL3`W}K8YW|X_+S)Gu0OeNlBm83~FvF)^U@UUw^Ek!Zgie5@aL1 zem~O1sjyxtU5uDb$iT2_SY{|KL7#eT135X zm-z=>3A+{$VpjK@Fqym}kbY^aM|fn)XU@L}wjSkJvaYaZ-C$ZNqgeh_oNqlpNH(x0 z=w+o#*+&8j#_>73=b|0sfu3{(96@vU3^x{`F7-*0GE z1>8bWx3(?1?ZYMGo%kH=sE#qY_nP|YP?yh^V#^n@$ZQpIsRw^kQBu*gocc{yY#67V z-$8oGA2~yo*r#Y9u6xQ&gAk_!HWhoWiGNGBTeI&nX;)g>6_=4`nT!~=cbxKFwl#Af zIgC5mE!2#Uq~Hq$aZU4JBW))HfDMM|^y+nssUf#5{94l zt5_x)S4ejpo&dcL;T<6rZc~Zi!a}|sCJlT{f{Dj&FRxA;HK+{_^Jfcl4*UXJf-v&H zxBS#be%^Xp6I(S%qae+zfE~YVq2m2>h@ea66AeeO6D>Jn)&+_}LN8-iHQ!AQPQ|e8 z|LuRz#Pi(2Us8D~LCu_BC}Apy>w9nu9i4t{;NHjg41U*M>GJGKo@nOscJDQDzT|}a z$oGNi`8a3t*!YN4T7L=7+R?IbWKPZ)uPe6a&e%EGO<)UtWMkr__v3<3j7~M`wTYyHHIPDCa`PLNBwLHgL%@nlL(R*Iu(UkCpVdM6W)1;y2sE zY`mYBy5fViK$j%vxYn*cPCD0()eHyq@15&?9!GhiZ-u>b;ErCwJ0FSp(rllY? zi*2Qt^o4ZXS0Zo4dtSTKU*&yR7||mw({!0S_ga%n!#lCeVE%^n)zoh(+NjW;xt+qy z<_L^5bK^v!M&f52-fmst;*#=&DDl7v=ZwbX4|1I$YT=~GyE4sWLKGE0soMQ2kPLES zkn9>QOx&-btS?31!sL5jR-H|=w3)cVrBsO|>lVkne6%qQk-!qz(MgWy%A^8QP(uQC zB?tDCAX-hVzYAY6W|wP@eJ{R0j7fXN>HKXByl>M^b=TfUY1#iw3|6l^Cw)_=(&AzZ z>68H?i6SgzTb=5=PjpI8=PDbFwyjW6hr|+Bzm(#W&r?a0K+g8gB9@2GpJL#d8-+du zC?7}>({lnl1?5zOSM$5%sopf`l$Ph&r^p^2IkaTOh!8Q>cnCF9drn)F&3FHk7qlgp z@5ZI8t`M;^H!lOMPThsMI?3AQeR*rQBAmzu<5}T77ch8C#iIwkA3^WH9gB0)TB_`pn;hW5g25{ET^&=u3tCc zIrntc>Wfw>><-$4C-GB#0X7WSl+G2O3>QdjFf{0Wl-I-6J4Wc5z(8L*E`~BB_Z;@$ zUU*nz@0>D`k=FP4#&ar*BJ1=Ba>K8QB!W?v#l27^wh-xL?nW2nwN>_ewZz0?P9j-8 z%sId~{rPjg34;xj&=n&!#D%SI}Y$+VYhM7z($b1voRJOp~5ER1YzgYdSWjb-9 zh`Gh0H78lP*Z%y+#-`xD6>f64OJ!`*eUjN&_1H^HAMVd=p%SE2@?A<-eWspmEfwk% z+2FdYw?0-eRSoXwt&M+f{JvdFC9UN{pa}omJ=Cy_iLwE%BQCgIqx^SN7X3yCvXU!q zC9lthM`AIpJErPU?@B@`3#(zNykX{z@FG#`yih(4Rr#w8DAe#w11vTF%i=&$pR;;? zyBt)Gk=|uQtdX4D3(;#HX|jZHo|efok=K=d7l~M3^IPn?F}$4wcef8_$%H1MP&s|6 ztT`fFGMgTgpL%+09b7yBY#(sxk|2uD>ltmRBv!Fqu~4 zwhj1)Kpk0J8Y9@SUQ4D;GhDJ2udeO^P8@Mb3vF*Q*|n(8KV>UnNGlP%2iyLos{}L1 z$%2k8-tM$OAi~p%f_T3>ZrNZ6+R`R8gZ^f4RL>vZ+k}nDOMBumND_(W%gk;xrXDzNd@WL%FA zUWn)Ym69?lDu(*cI(m?e7E1G6g^FiwyP6$R2AGgS$*4j{&S|D$p<=VTPKH_Aua#ow zwfJM*QNd%W0lAy1nLMV#2^J&pZVS(is7PVy;)M?rTaK0WOi;Jzg$>xZ9 z!9>5xXcScFD%%b+ax?Rp_ZAcR1{9OUVHNg7Nriqebxl$|d@J?I_-j;Q+V%wNTB&K& zN9iCZ!^vP&L$O|zs>-*oiPnH+_p0 zKg>Jc+K|fII3b1YtQt3#eiRDFw^wv(>hfKq{XROL$MTwVsi$vcbUsS>8cB&wnCkX6 zh4}^veoYucm>6l$WWxFgIwv4BDWVkg1({!3GW7Jv*6)>VR&+(|P%5m4>T4ffL{33R z2TY^<4TTDOFAob23edETH+}35BY9}1{XU+;hFrVjQf2|QT98g^5Z;;!i%4$whQ5%l zm6PE!gzyh|C%hJ_LR_=s>|BvQwp`e(xzGxoCmZ)G9oLEO;Y~N-U!aV0x|TrlJ z0rriqHa}J~-O=7+821;>T7!mPa*X`N4CP~M5|78Y)C*PA#YBGG3}g%>e<*V5;WzGJ z@tPS`fa}Aw1WS|+LeHExm)BEHlw#unuVp3}6X1HGR$__!iJeB1mB zj~aaG%MGs%AcNwoirLU>e(Tp#FCi%(uPuz1uBSsnTl6`#JiQAwI11#V%GcM;iY{cw z4=1sF4XX8uf1vSTKvV^28+1Hq4$9C3!EWmseKWmT>F2xy1RZ9z8}cONRq@b;)WDz|$yYpm~Uu+7-(Nv&3*&UsUT zs$eFO653if8})t?Kh3uoKUZumlS0Q{`fK`ze)MN`he@BFs%&$3<%k6I8c{x&%rf?l zhIf2%OrN1H35&`zOfUU4Q4VK-H}R6b;PdD5%t1z|qh`%Y?&mw)eCP4%*T<@$oDhyJ_wghy-4tO1f>LnbSnZ&QM#Eh&#wBaj*^f#kgdv{1NSf z$*o@IK&UU?NATa22OOX{Nfqhc&r||ojNi~l0P%+%<*`(z65C*F|jrvA0^S_ zEFuzbmO8E5>r=W-W4}GU_Ec8&s=Y8&d#89~ImvW%GFUzMQWAWEwSm>x>|B`fNAMlS zbUp`nbbGFGH0#o!1I0aU;BmNumZ_1{Q)<`G%=t?yym@Y}J0{s8e5gQ=GC?^#yCGFUnMDJ>H3f zttZz|b8Yl>kdykStz5$kKE1-mpNzY08awGJ>q0p`FiCi7{A?R6mfA)lHQ5a-@H45M z3XJLp4HO%4;2fHZ<4esxY}GkbgS68K?l2qtugG@`6oc*`VA6$ z#@vCboaNAZvVPXiu$JBl#Ol{+1D@Q6XXbC_CAgedA&{X z@k<+(cF(hR#+fHFW5z#WUlKyqF+FfiqztbAZ3Mt%{bh)G6OVK?5Cce32DKIH#+KKp z$cf34k&{dF%ywBkY}w&Ydv<_Oob%DFzgYujC|oEyx4mw*rqhsC$&~_~4< zsrGyx&W|gc<&x7fR9MtI3(2vt%#EG9)3Z$lMC%HUZIiw_m>ME1xeL~K8_rez)cuCA z@_J6+O~j7pu%hCx-j%Blp$ z(0pZXLU3sKjTUM`JnALcO28w7EveudBye+#61731hqwB%3OBc^2o!8Pb$`3#`@?>tNjf2D{;+bHMdpO4Voe zFGeL^JTrFA(Hqk4em>}YL5I?PSh-D!n(J{YH@xl~F6!QsO%-pwA+Ol;c?qZdp7$6-F|`QzQ*+ z1OW9=InEWhgIN~;r?X!Jv55^JnUMisk{ObKZkt#4`LK-{qs-yPS8sakH9IBnUZk?6 z^@dDsOCqucyQs-Vy88;oVAwMJS+3ZXldBRBB}S1p13WsbtwR`+7I}J)n><6J-eaFD zsIv+FlEo#swZz#%Eo8G+?w@Ohkk_sLZa1Nh8}9ko@dcf+7?1dfJ{uP1 z8RKy~KI^fhjc!%pZBWUuPnnoeRB~>_9M2!%{3?B_l(S288j)?(Qz$vGvHH|ZHGpge z*?F2n`O8eJQtVDitc}*A=zQ zhN?e#CXD~hD}=8fYuum!R^l+J5)d!wZ58mqNvxSvfNP>P7}Powu4*^gkf+^U+ac6$ z#6mb>?Ttud5e}Bs0@CKC8AH>{S!q+bc_Ix$tIN2wki5bk#JR3nt8*pDz~V+4U+5FP zbC#VS=11qVwwb!)y>?tXt6$iKlxX{@3L^$()Zm*da)n9GbA1Tp@KV`Ilemcz!lkO! ziCMzcNYA$?+m5OhYkuGcn#VZbeJx1Ur|yu9>SGj6+es>jq41|7C4Pn(uC`q+>=n{k zY!y?!E#i$`BB&wk-H{VHuQo8Hr?38G)O+67rT;ekQ%UOA%s6qmSo?Fv{$I}6M@})n z-dULq_|4ZFI~GRLP&z)%I1*d5iMzdHd;!T1D}zidVfBd}BMcc=7%#zvL5H(w##t`4Q4*30J4R~yV%?BiZQ0TkJl%N0Y4LZ;>BA{Zgyw zP(Af7Gg7_n3r95XdTZzWx?*z-Rjx;5Q)BLTBf|gV>Ad5qjQjtel!hcKsVTgD-d!|!t6_xJZ-fAq(BcwC?B z`n+G``5qS5o)Lh>CtN*vG;>=ex2x}r51!3MeDQp85xTh6d;h4}&iz;OLO4U8{8;2N z7w-jX?q|EJMsUYelYA+WW#0Rg!>$v#qK<7~nb_SeX5FQ-^+QCWa$|(^@u#D7bbY6s_V-u zS4m;(2=^NmS|9XYIpjipV2m4M$q=y{&T=er2tYsQKN0Bn6m)7P#fM8tLq~D`9ar>J zYiiq~!K^s7ez}(X;6oGMMYDTb$9u0KW&X%WB>&rE8jPz7u4yb(9-Zh{h>_B>#buS zMo7a)mWjS0JfBFSmOD4y0&&J$<|r2Fb>KP*cS?VuTCLrQl(BNL4O(ApytVW+YlLqy>401&<&>EC!T>a0k=~x9f6hNFi-MxG7bDu3*d;k7DJFU22;Qjw;0fSAC zM?KSAvJ!ioY4`hI8`?w8Ti5tj>j&?_`z_H`p|(N!FeGg^L*%KU@?*UG{+XY-_L-Ak zokW)R4wQJpGiJ5lGK&W_y51@PEtkUUz=U^aW*dm^Ke#0e5o({p!~0&4v*}7fe+%BI zxRtwhneiIN^o$zuzS#H_3}Wlwb#MPYx}V*e7=c8OW9MlGVHb%w$`MV<9a^?hByvd@^HDXQS9uD1nooiay*D5@={fkygPQ1uhbF! zP=yM2!rckiHAhbjG1vV;1GIuzN`C$z#Bnm(_kzJe?uDR+6H&xNHcCKM4F8?mQ2F4z z0%W4Tiy@0M!HRI3hc-3%hVl$&fa@k_)`-t0qo7bLY{GB>)yw=PDwb zYA{e&(3!Ct+$r&q%VTtJIrgrs^{j;H_V3@iISnh|VD5*68YKs%MwfbJ3JE|b=rVL> zOQkM(m{6FGXz&K!L;vcY?itS6ek349L{mZ zCtGAcbd5({g;V*tvGL{tTe!)8A))Qau{#fB(1qCHr_;Ca1#8DQr+D)^~&($6*(%1W6k7 z1?6rJZFpPk4(_O1G}-Tcl-T}?=PRW3Z$_?5nBXV3&nUnSjj~r3DqPspJN}~3JcGmj z)+hI1o#|zdm(fs@-VUEe?aqWZ!JRD$O}kB(-3Kzi)>&bfu2KVQX3jw}qRD^%P&;I( z83M;D3$nB0O8!Hwc&H6M)0F+;dzbQWP=UcmHh!hK?RhuB4P6Q)aNjwFNk?ACwu#Jw zHF8`0E;U0T*~hqf)3-mjXe%@4_%1^2-y;Rj!~Ox`Z}rfBEEL=Nid{`}7v65bvJ_PnW1d>z^#Vqs}wXYaRtnG9Pg>!2Ywau z&QA@JwS#p^$jy*YkBeCno#^G@gL`oS*tj+m7%}dLGvv)6d%N${#Q4o|$K%IJYlj*q z_K5cV4-5w`<(jLDnoBU+dEPHQ{qGoxax+xb2gGi!+DOZH$4*UdfL^gkh2ZRWWv=jU ze&PO)PercjAE2$i8KAUCJ1&K;-&%qXp63nk8pma4)}+eOJJ`k;K3t0yFLT5Ru(y6` z^yU^66xfMv`x*3T>{nRcRCBu?fd*Kvb)-d;t~s@Ey4a2T7>ocPEEd07BeMXieN$xNu3L~%Nf&QVqFcKzn_CDPj{5q zt*-sk<(NnT$#s(Kv#Gxx#dLjmbR69NX{3nG;6iW#h#ace_hqkT=4#^waMLHp)&K{*>~`I@UkU+B-LW8Vt2be^hmyQc{O&jef?g+ zhSn5BudHh|K`|pVASRI;c{#T2VK68&DyYom6qeh58}<$MukIR2?MAVbP}(nmDuZmYlNU(BoiP9hJ#ts@PPM+WLS+zE` z6{W7J=;4*;&yY{PmOB3R&FVvlRrom$T%k z6XkbJX7bpo{K<|{O1t*!1`{*WPYiah0opM)#pG;ZMK^G@a?G-M! z${Z&KgSDLULiTAfc=q`nMR}KI_*~_7Z%yXj(-%X3zezeQJ~N7sz^hK^Vu=1WjKjOz zJGIRVLB(}?b#~ssjPpb~>uK1lD&Vso8oL;nqv1tP_bIjsnzW2tWo_G^^+tr_lacAE z5Sx!P^tJsGYO^FJ{8 zoia$fXX9YJ3hmmgjsHZmXfzBU`;EIaLloJVzDm^wmo*wjE5n=re2(ZL;wj=yG@~j+ z^clcVsr#oq8#L~@>T(jQVYybe=d40-tn>4^mdJykh-S4{Jj@zZIG8;^5%jD>#g=lM z5b2(@5W^aKsJYF{=o4SmH;tAd=@TWMxTT_w$laZ)Q*? zi$yL)>z$TvW?P1DcVXE4uOen$dum-yIRkDSIu|u#aRmDZi!Qw!n5lT=;&X=V@^X0U za)XI#)3ic;Pa}|axzTkJQeeEB;r7b)8B#$WOdq(z>8*hy`61tK-N^}MY8Y%nygZWH zYQDODjY+r811%Lr3Fw7C0_@-C3D5=+yoF%IfVsRZ=}@J%RGoW&pTDT6Br&}Fr2=$s zc76@YjrZwvekO1lDV=Qn-o&TQi^%nNu1jXIOA@UV`8EO^|cvop%2cJPGQ*y}Po zUhmb+IOiH%$FQ{~A_N%y!k2k;zBXmnll3HNtHK63V(kZ5Hb^^+rWz=Dz) zU?;>`=4Yv^p#3=axM7&I3TlPqw+(rPSBsO{Pe!G7-Mx=B;yLR(i&rTQmtA1auPiR` zDKXn7)HIxPP3=A0bdGlnJn+JGMs8=Y$q-&YW%3_#{F_f3b(^j*@Hj~ZNj4!V8_%sW z0d>yHwNd21#=>f``Y6h*<-e?7F8(VaYa^u|hfcOxlqP4b>D{v7QvXnaf>t?D2&y8{Oi0opqlV@o+2T$>k$lZl0Tnq zD`vRrb_>bCVik?f36PDQn?1NW3*%d=XW0IyM4z^PLlI}7Ztw4^P-;N)sHsqB))>ikw2)I8+?tbD&G&ZZjMeg2EgTPTX> zwoko4#EQ(A)5`{l%Mvs}XbKUZFmDsJ&Lf~0oXl`UX>k7T$IdE_!u(LUD4@UObtYb+ zhSIo5_JEV}@)R2%_)d{Th~j9uc2^VUtNkr)g?19{&bOKhz2^_wEip3!0aqKNyEXh5 zLUEC+ZgWZl#e-dWA%Tu9sl1KVlb$THyx}7uwl_{ht~I5DL4%wx+?U=M zg&nVFxAwVB>k?iMeO-No%O~vt>YU*IPmo1P>u*KwZQ6JRW>PIaW=&MvzD`Ri4jx2! z4ctONXdWiiN+9L$T}-Q@P=tO|QU;PazI^dMuUm4K`$By5w7C9QH^0N=F(hCQI789#t|m#>A1{G4Zs%}T?@UX?uJ`z2toE}c22~ii065#iz*nP}gUptyF`G7fbrC!^0!QOwqGr&mP*!+}93! zP=9w6A?D-ixR~`u`(>H}77I1$3O(NTHa6=2r|@z>LZ_3Cy2=j>1(;Nkps3Ya^I36I zGGGtzhvJh`(h^+MV(_f#jfOZUFTS48bQkpu1p_p{XAT(YqHrAfR10!1^7Ciu8?mSh z7bGcdjs~uG6uUA4KRg64;F(0XscC^qwVFSE{#+yftdMltIs^F{s71mqt`2HJQlPo+ z_sMcD|pZ%=`djeTurY4~ZC8jZRuQPbGy*H*DR!lN@;nOM>$f%%rdh9${UCFHe% zr|9dt#Rn#CzJf8dC9Z}cVW0KN+C3f-f6q;9Zl)-v=?Kr7TuiQHXv&8f!me0OQg_$+ zfj2(NX9R@dvkDXxc0DJ9yraEjp$%iaGQIR z-D>``p4m+@&dzplF8U$5Cf8fA`wIm-88ZP9ZmL!8-2U9y%U0H9+>ei%OGat;Wg2*B zcaE2VLo%FtN54hcJLEV-zP`uQ8A#bV{*2FNyWamVroX!A(*;s@6x{~dZ}TW^K2DK*X{Zx7Z`)gs;P#J-T$1QcQT1H^xLe|ngRSWAl(Sy) z_Y7n{a2yQTcN{!j5N(GE5+7U{btd_seVoV4vw!eLwKy^j4P_|@8^<5hQ^wr}j7~ZZ zZ`c^i9q%)){xefc*7;wg{jGM^xAxlRWdF9;r?G&N{Ot{!iWrB`f_l_l%12@(&iCt} zJ7eSV15}XpY`1D%=8C@CS7$rvm;7G6*C59G19vJA&_Cnri`UbAs;Lgu|EC4$g@L78 zP}nAymv*c2?&Z&ilUgn=URq%lYeex(&V8@ru%2mEiL@kf1}bu-HcPcA>@wvBFx3sF z{FVNaj}3L5wt45OEy2WECZlp0gQdx+Ih&X-g&y593ZC$}tqP>EyY(OgUY@T@|%rFGBV~^1$0(sWi`JA3(9%w?FUR=d|-$ zo*e4QrGA~*7_rmvBh=f_<8ZzaZuzK zLA(Y5hSj-;2mc2$Kve;Yxl zi3Bi;?&t^o{r%Y(r8NmJeT>06ATQ}XT3r??kow+ZT{VQA7!h_;qqNON>O2f`cRx%v zVs*NxUxGrhFo)z*D%Xp4Gt2zFR#F+J(gO~PB(v=8C;aMNCLS^{mCr4C zBqp6fMgiARjpb#~#?Rn&*^#Z}r^uOhXFrpKG{4o@Y}_y1@@ALv9EI@ijG={N_r67d z?KdauGFrVEEkA;SzpS-2HRYc={8u6+4BR_8efo~@o`%9<=XP#v^Tr~eDg8Zzq0u|) zweBUodBknD@O_NlMi{7jQE;^X&vmC}^dfIG7inQHTJghzD!?LOz*YsM9#!2~q-{e) zCOb2*)J@;nqZ@iJUcJLP(BgyTQZ9(rOZPek_E@CW>rF-S3+9(_D*Jf@Lq!C>N>;jh z6GKKajj-QUStHmjf&2o1JJchSw&-3@@3rL|(ZLxGH3UThO#(w3*9O)DdFz;Eo&RzS zI~EdZWbWzlRJcpQjdG->%kac$ShFqOuV^}3ub|^i*q*N*ta*QxXE3i0)&BbuGb7Ua zazppZe%aZp*|T$egOM;mt=Ic3+<45TwYLq~f3-xY%*)yg&WI2&NSw6gIQ|w@P=`-{ z87JlLcx1Nk-sM295f^^rMPS>@+FHICxUeTDs$3AI_UQ6WWn$cd*Uu0h-yY^6XjgUS zf464(JqF_5iew+i5OC8|?_<3GI_ZL+QeI{<{T=0D)(N$dIH!yWd|U-IM`&UvJo|pC z=D+(7pS-Zz!8d9xmn`SrMR#JkJcL0F&(|@{tyT+gl7nqnHINtGEOSehzAz+gZ?#Bj z`exkel_k}a@^N-U!a(60di2NI^%P+rOTsBbXjZ1P02dJ~{ZE^qA_~&HIWOsGV=JXm znX$Z;fFitnCwTrz*p(hSaKA$L_c%7UcZD?3UAL~@DJ4}jMvZ$B5tEsQxD@A(AcBZ9bHGFru`Ufi~ ztT;fCo?gPwjpaTqU!8qQS!gj)i5@2ILVp*i2)=v$TZ6hTD%hddWBi?A@jCL=@qE&3 z?6>I8PoKtdm2gkh2VfZZot!PvuoxwVXYJGJn02ee?^iC$e%61*QHR2gcpvf0_$xxuwwNe;jGqc8(jlB*nQ#}cHYx>-bO&s`lJ z|G&{u&UP{ukWyfvbg(4$sL4(knL!9EU$1P8?);mqMN%-%y_>VSNaoheBbx&28mHgw zJ=W>)zEecO_`8a=jb)2DW&r^gCJ*;xnj6*zmIKe}6T!+ox(j2Yy<9I#ZK=~#<0i(yDA9Pg= z4cP;7^WcUYiYux3ti|7IZf$fkqFJ3rkEdcd;Dt@>U{83NhwRxd_D}VHvnh>@f1g5_ zmJvTjf2NTp&2|^1v=8`X8J5d>JwvKcwKQq*zO^cZrw{Yj>aBAf!rq6EvF^xM3ZCQr zj|OJ=pY)coF^B(m1;^L(XM#`Dy=o0Tv!wRlAhhSLtT!49ZjEORT(D!~?b8ZOM=oN2 zb@zOjN(*~jOCN5Bxqez4lV3^LB~3jOoe=oK#uOs4Coqb0hE7~l=nAT8V2^jFAJ z_ckPg@oqnDcS}UKI7jEkJ9@82XoTsOoZTOmyNM}h$1D>Gv{kcuHnffBMePcn)a2(E zRBL%sKx-l;bSBmB%Mg)iOrz$0{@84uN@e^93-^n@j}?m0$yCTxqrnj~_$5{@NZydc zZ5i;dWwP#UX0?mgq^_i8lmI@T^=DI*1eem&$M3;GH>eU(>r?X9---(BO`8XVy2&x{ zUZ4-LhHj#v9S5G7>MV^ZZ%Fkkc5X($Ry)nY!otdReSRFENE`d@ibLLZJqp#B6*rU? zR&CmM*9i-%Ds}AW5r0|AFj2o^aWUew@Tp$=C+W8M2E(jq0ag3~bJ!#0Gt*1h`at(O zsIJOj9r!IXxzO=r!d!cA?Asp?bbL!1aX3r#jZ_->qP^{JTygH#dIX@7ud_)w=k+%k|9xFU5Xya2bYiufLldbuNX?MMqN#HsykFoZcCD;2)^*Of; zs_tVr6+P_1n```Qaye7P;IvyT&smaoRXmf!HmbV$kT|m2*MdB2k2p;o=_4+F3?J#K zQS%@YQg(fNU%6uARR!^Q`-xHf9PkZkF8Ez#Z*Oq%mF{^zC4l_;jcgE4!=CGzs2ehF zb)Ah#vf!bAEUNTXcEQWmVT|=c0<>#D@V~$X#S=zp`oUkRS3s;h3x4(o>km?`hME=Q zT$2o!HmTq{!`@ILA;?o&`P(LAa@_#PwQzF(yNY;lO4x3d#XCU@ty=oP{?}DoPFlu? z(3&}TmFdOi3e~rzs3>~Ru`zA>i#-|%u)J!ZVdOdC^>NsK>Xza|OX82yf(K&YleyXn zT-yA~vt(M3;ND|SPiVABfBk8k1B(GaHa&s?eZ z`sruCc6a7ugq)k}{u4Q0XenPfAZF};=tz=;#2djKv>@?*ccn3`USBnytRlvI(@v^1 zn?QsRGeJy^E1Z-Vm>syys&-ooxLJOEVdoV8(Z*-qdN>ISi*kL|(%qQiHX9dJw_I!xNqB- z*{#c*!e~WR{dnvnM(1)p@w%7$@H0-PZr>M|R{`G6@?fyBgJafL25bt??DEsIJ=c62 zN4^TKcl%GE8-CFK!Om^32GmUuoJLd(wrdt0AIuuOc?&vI6KL)DB&QZ0AY;h$t3+yd z-`FsV+g;_j###EP@d-qJm6dC)>r130sx>Z3qgT5IhkQ~l#f+z6~M!3GI&>H{!++-B1Go$;On{`=Jg8bCmcB``1Q|R9^|MI>M9PDvG_nz9k zrGzlO7-&(^Q4y@rW4+4$U2RI~plE}_M5z=+q#l;9Gp95GN${Q7h^=jg6%Q0K?~M$7 z#x5x{4?{UJJNDVKdBuiQ@W;35KYtN|pTIn;OUJ#^p9hDaW24=y4!H}-W)bJwqW>eL z-4-Z!M_?{0Ts}dH0Oi-D(?-FvN5(m#Bv-wOH{-QQkr4B}3N7g)-~Pgqlqpf|-EfJ| z?BctX!zmdEi5CeK%_b|n>}wkb2n#=p`tEi<(Zd`%7f^BntzDbwKP9P$c=>)#f6C>I zkZTZ7+W_}4Rrj2`W#xaqii!G%Il}u=@zDBkU0|DQM<({e8wL-V+A~&vb%E~YHQIOE zn&$71YsN7qlHB%AXC|+QBKbk#U!1pyHkP&fyDL?0PgAaiWQe60 zqLM@d#1l)&BOtmwkcCxK=@UO~ne8SfSW_kebSW{Y{nUSqEWj9A3MN=ALA&lP)gb~d zM0j3m9|?ZX@CS>3z0=ESfFLO>@j#M{=csHOg zvCFrx=FTZV7^NOvUq*T$wh zrYAn6Hy%%9d?(G+U{USg7_v1lx3$xS?)qt{ChiUYTfq`ipz+FG&g0rP?`>FPby(FH z>GmXlVP*OErP5I74Y8z+(4(G$U&t(d1nwsHbKnX7rCf(+OG0LDJk8{C9CX}ROywrM z&0e6_=bDEHGJE@$?N=hh)u>nyA3y)WNO+M+WtZwNFX!jhwLbB^H5@i?!KjzQ8iu94@hz3PXCB7JML$Rln?2l`r}bjDx#~e1XWzeX4|Z=lQ=A?c*%T^8eGTN3AWt zYo8cqEf20ZUxk}wy&KK^`%bqj+>Oyqgj8n)bTyODKB2GqYVAh<+)SM+H>kL6wGdyz zKii*Z43wcvyzX5muyar{lBfx<|9L%qw2~ir;aL zaGPirS}c9lf!mD*4g>Fnr=T$TW#JV~2-l9Zl9lDTbO_98M0`QUq8wPY6t}_p+T7j_ zqVOK{zq8eaSCZR#<=kAO!3hgsXZ7w5DZqI{Sb*fki8D>EF&cn1uRAb$HxQTgW3U_6 z5N4lMLdFJaghjAFRL(m?^l<>{sW(*yy|R#p>JH+(L%M$oJ?VGZRlS5;2azad;;nk=9lgn9v+UxPPQ4 zkOEAi2{(2#xC-7>=a?76%<>By@VrJQxWH|K0=1mtS-d9edLSfG&--C^z^5!-Yap^( zMSuV);3EGC*mbc#{(0e5r-C!TN7i+-o&NYh9*@{k$VfQDL`rkyXWdV6@8HJ#lt{V(>O$ef`zRoo0k-Xk7AMIpK8A4a zT$5*K|h)S}&Dz@s=xx^Y5BTYZ-TAsuUSHlinHyr#Cci~nyC5T7fzek1Kh&BEF< zD|E4!a%frqMxpMCnzn0{BB}DSN|NI2S8S98wQ;mztt-gJqcqRL6jtq|bVauK^p1sj zt$%a2)iLKKllml}jvz+PcUSl6TK1Ty>T+=YOp5G6!U3IrRqC0cLf|S39VYGm!hOAn zjRo+Y;H=hk)us;E&E_+v&6-gy(aF+1jkJ(xaF#Upi?|^lALn>={-4B0#ajC3T0GdG zS%JMyL=!@LgzQ=4=lWvoih|>w<9r6M*B4ia>YYrRe{a(&NEJzvtxZPO_7L{BH1G7? z=7XB`fE!5+R^{K|9z7`x^_rTsSvS@LiV9tiK8ddHd3*(3B?i%ahq!htc}t^a;nq+*h0-`-mX}OG=I>3esUgwkA{mJ9d!fqv=?+w@n>R=0ldm- zw~EQQqS8Wow7@wyZ*kU2irb#wG!=AAaR9{k+^(!wI4`av6x~&ly_I<0HVE;&H0DAS ziI;f(@iKJ9Nj%M#X#5xhbzkujUVK+hxrjj1kqUN8n)AX7ge`uKun9{yY4S z0+zmL;o|m(K>h4YTwLd4V!oix<$hu?+6DYcmJl{6MWaZ|st13+B#DayumSW9o$@5b z#r@O}5AP)??xtS-%W^*M7tL%R+aktXLQ}>3XT2ahH!fflQv22KGg=_43taEz9!3dB- z_l$mfl3D>Gj&q}z!ULGD94=Ec+c>MFf&8Mv{P&!>g%Yv&_!yN(ch@UYch-pyv$GY~ z$;mZsLigO#(ltr_x4>{Y@vHtE0_Qg5eih{LOQBLoV~LbxvHqk9fmKyIydtHIO7Cn9PFMBs;eHz>w0s*QreDECXq)>uEQYU%oaZhh@lxG3-d06NFaY1eMCX zBq+Cp!h>r69S%ax!$x~3)$Ku9W7RJ_Y_sbv{DnQ<p8nWdg8-U#<-fCr{iL z3c*1%ftg}lafObz?=d7wxL73Cl(^z_(g&Hq?dd-M6`jfC;Sdk##`>)Fd{u=LHKT$U zJ{=V8pOreOgfSf8rNF$gXnlwMwqegJM9ufMB0kc2sgC<5*ePZlmbj zSf2m!ZEbrRVwankQU!-*0xsR$adbn`uP*G>$8jSES2rZ-N!8R0!c6D{5V$BC7KNz1 z@gT3}#t3g3jUa|98l?$16_p!|{5-umeE;F6-x@vp<1PHuPIK597dGIC0O#Ze9n)bS zdt0~txXa&khCiP52(c(!Vc~Y`1Me?!x`&<2`=(#1NJyO8#mFWrGDHr0b zu(u4@M`2fyh$}Tfv~00h8O>Pwa1g?4DmaEnaKKx3LrHNgCc z%aZzP#k$G?DqLPskN`{`Dt8&Rc}1Rn-lET3WYE=vXxRE-zcACltvxt3;hK4pV!dQB?nz+aGW|(JR`)n}62J4rj=0YH z8abM`1kp+s)yI#`p2@tCI%<_JB5}9X5)ue|f60LBRy=iMB-`o9qm{GU0(EB++|=#B zRGB7%F8QW4l>f=vA(x{$5{k*rH@?_erUp5%E=Xd&&cg+65&Q2tm1jtRaz)}-_5Z(O z6GpwmKG5O#aks87lbr->K$9jVIyS@na=h2mGO1^4-TqYPI09b`Sh?Ub-4);-DkW4!^oX05WK52&*mi2xxH_L# z#QYmbN=>EA&_qSL512EBimi_Vg1pt5L`Q4%flGpqzcX-22%Js^h`@zupVFW4W8`s? zowtidh-;!LT;KcUDv)^;83hhe5?3`#Ax$#)%tM|Y8l@@&guDalcE=lvPf{{_AO6Cv z{!skSmDlBDak~{Rz8^r8mi>n<@|~qGW%g0%)TY*x)qi|T0c49M7(g0Ot%>1VCE3PZ z-Vpp>keI&_#c}&UUPo^DmA-GfN-?Uc-O+M`orxPGOh~kFPWa3|X?n*vM3ug{O73RmEmgy@$FsWy<3{Ingq3 z^xAPdid$+$l`#&uJd3!TTY0R5zfhfTtrRZ$Hqv7sF93Ds`d#Mb}PcCd6tuK4H zTKP>`NT_`;~85-5Jyl)OzpH_1iBWSjhVcsiUIU^?s~#%d|xF;qp{%%&3Rx zeoK|f;ByXm5G7%8u{{%DP(V|={}P_3yBB=;cnS2()wQRtU1!HqQrW}9p!Zrjdw*yo@V8cZ(BGmRXJ$b_?D(zC{SBxZ+uR*N09?jhLA9MLw`1B+#psO7Qmo>mclDeQ*yCKRb|liXpmCdncSY>apAmK= zzsv`!c6vuSoGtP#yp~Ulv8WGt|Csigwh5N!0<9d`nc0QcxT3OM-_*~D~tr@ z=0mX{F)LQ3mp`2uZIlp}L3^vWeJY|50foQ$JqJFt;2gehJCYlM!kZ&br=){@M zg<}a*Pk6dW&yLDH;d*JqZp3Mszun`w5j*=YO*Y+99q)-hUS!e*SNTho8Bf{K``-ZU z0Z3f0S%^3ZTskFF^XOrynPe5|)DkyxW*n@7n7B{Ca-A1`<3?aj?VttJuFAU4Gsjsc zdFyQl@Fs*Daw*OKhkN$+;VlvQIz4d~^;^%*ES?8mj?>y%#*iyx>BFISO=M+pR& zIoUzDo3OUAOl+dDpd4LU(?kSh7^T)A`_U>;;gauNQ>U5-98Q!xo2zqyo>v@kc85ig9u+fI*PewV0Ty-m;# zd&sQxQmxC2F-Zgy^4L-mK7Rg#gHkRb7&|p=XPasW|pk z>0a=xwe8&AQ+z4vt-3c38y&i62Z8xbu*60VcsyeL?#(S|L7E0D7fO>^Ye-=nBhoCJ_`G%{=lg+? zF_A4b)wnYmI|#ud>c%!S{Sd^65VD^c)P3Z#4jC&j1gF zIB8V+{tcyuD|~_d%omm1jsSkJqv6BRuNi_HAv>S=sV#!R5qK|Gy-Qh9>MA2Z_#|3s z>q!L>gce|WYg1JvajKmo#}4<4zvo=ytCAdSgL<_)YHEEefJ4dZVAtyqYuLL@mZ2xe5L`$NT&xEkqN0e43?;M$CskYlBmrHvrrqdg8I14WQFSeUt{eJWnfNA zuPJjcdjWIAE97FSj~C+^-5J9bNY_KPrhAj8FAlG&O{wq#an)i;#15}1D%BdD@ethSoI*AK@m-81{5}MSE zILaE5lcrIjh5zK+I|1?{pfmsfOE>Uu(#0oAr^|${^ba4r&rp4UHsx@^b=F<-x-{AM zIy{Xar<=#HU*+)R`60I;>qy=vaun3?WxpF#WMeNhmM>vL-3tf3i=&M(04$-O=QASM zGj2VXh@!I3D>p29c(Veq%qdLSTdHK+_piS~EE!iI644)KJuclJzt*Vza-H#S;ABk!wDRWE3uOn*ypDVd;GwiH{oAlX{C?)HT` z$>x)8S^5p7u`|{I500OMe;+6-}YqSG! zPhYH=0SL8Qx%HodtF7|IK26)XUo>T|_{yPsT1HDQEsJAh#@sxR(_Lp340o-Fy zc~VFN1<4-Wup7Fe$UiJO`wSjf2Cx9r6AokDFMRh;Lle=cDIb+o#bQSZgjsTBT8U3R ze+2&fqJBBTf3eXBQcyyie1#bD7h217I=kC@RRwMB>wP}r@=5PrF?Ap@ZVJk`-63O~ zhq*FF8j;@4o)rU#QXauC-@hQK8FH+W(-bl()2zDT5g5Pk*sN3FKc`5#Ye{#2wOV*P z_|e+liD=l5J5snB;KMJ&uf;&wfE3LG^yq}M5y(e?hAm|anHkLbUaOWYm}qJmmoRF2 zuv5uEP-LtJcHYVAf!7xF%SuZno%;r#sX?;gd1I54K!IUCacQG7?QF!V;1LZ?$IT(< zeLVmi`i3HPuerR>QvuO?L-cqcb>nxf8cj$ca3B9#gGqC6y3(-uz^v(bUam`>foW0T zpwI~HaK(F7<~eiMr58aq_QAm)B)nA2Z9eA9TG!-tS@VR5h&OK5bk}COD|MZjw+6~! zoUlcWq$Y8|-U6_3zt3xy%AKV{CM9+>ypNdgVdnu>&kh2O=E=SwhhR=}X} z^RRG)5|}#wC(|y@Q_FJT*)h7h-RCg*p{QwD&jepBaJ)SPyI&rO7y}gJW{+fELeF8X zJokR!^p{2vY9Q(6cv&+)^@OWKC_#9#rsa4xDI5dH@)7!(gHbE+S#DX`InP9`CE;-~ zl!mYS>a3r+?8N%Yx1ib%wNmOsv2>sMZmZ44xn%BMOH*slomtuXd2bG~<*o|Ai{?{k zKT!FYr+x~6a%8#x9Tw>HOZ_U-9BGb5Z<8oT&t z&BCQ!4}~L7rKQQqer7025dG~fJFUZGa8nKZPxtcwX#vlP>#?)3Us6Hj zaakOY9*ONPAp9Odh+k*w9(?^kg4XXiw)VokFI=#G*v#6OOEXu~a*W_J*Yy7wyY4`! z-#?yJYi%xm|u>z

    w|+jnlMLOzo}j5>;VI^H}Bj`^-%YAdf&>@ zH+4jR9$P6@mW1eqpwhFEX*;C!PflhI*by7&7r1)-2A62Z0C%uEO{=Jq}l=ZbUgSNkJElfzI!En{mT~^mP^~4$JA}|;JlnZT|>sbMTxJAzry@HgxVGn zmiNxXk03Sr7IpJGLJDM`aw*0dS#L>Ihj$sKW7=i)U;$2%=!Ewt;i_U&l2X}BQ zp2EagRA=^%t~X(tuXQy%?~$Jofn~@P7u~H)=rZ=+Hlf{*pt5&3&xqFmJxJRLb?Iqra(~7gxLUKj+7&wpa)m)9F($qZ<4_QrG6{N*BbQ^Yy zdmdXd4A@7X+aP7U3K!P%!}l=5pO&5=T#YSJg;*b%j0)z=AMu7pEar6x2bmF z!2IlkIq&)+0}Juj9Yu6Md5RfP+{}~nV}5%uSB&8^<}ch{TJ7PB=Jk+tKuTRj4pMiX zts0;7Gx2*3-jJF3W~$w-yx9E_^SO}NGFwc#rfM{Kgt|;YB>Llp4BEJXjI#!=X0gKG z9sK%jO-m(_Wtu!T4@culFexP}Q=K}P%gM(sm%rI1E%xf<=qqSQaOZv)8fzr2cZQ!D zB&KXL6noD_(pZ8p)jhyz%gK8;ENGAFO+IT5|B@v0Oj(z{%6dS%itV)TXlcT4B_=ZH zOvA;;H^GIFZ;Y@iy!%Z9GHq`Oygr(t+3CSvW7J{ovF|?Rt9@ZKMH(_1OI zGbiQKO>h~V*Ueq<6IYQ9!_2$Hw)~Wq!8R0rc&G*>PVcMGFf0dMYTA9StKjtGt#;nhfR9K>Gh}I1QSy(*EYYf+HdfQYv7lX+T=c@$GK4njDTL< zwfM3>VPua$mhZgfLf^nEovC?`soHBtsvOFtGfiJ#Y<5Gl^XzOpt(8_!%g*6&NBZ4R zO`mRI!KunoPyWe7fx(Xaw;hGD4k8u?vHj1SQ*>iw>L+u;rzI^!`y!{m2Yi0#_2<=^X$!J*vbb zT{(D^1z(;0jo8O;{@~rv{S%MNX6~uKV2so zG2*<%NME7y+BNf*H%>p(!~73{}*Zo`8p_TnCZ&QqoiROqlb8>_lr##HmNMxILxH$ ziwyKJt@#x=y^~)DTo3`>FlTx13%Dk_Lxex5Z8JMNdue9UPh2NHP9z8u((TKk-+IG) zK;xF;5$_bk4`VN>W&9kBeUJ+$3+%suqfqR7U;QeJ8LQg@lXPh)zaCk3G1RSG%_H;Y zTb6B1`}HRLgG^37sw`Qo-INSoHNze*ulv%1Tg0w8A?oy4e?fX+;)lgfXQxY#TcFy% z?evlU{aT&I5-CMZ$Aetwo-t$Km!3)%sXM826}pOD>Se26QKK`)Kz)FP0la3JnfTx zWIHupdDb0$%s#~AI-kY$$e3YK(?@4fP4>{012Wmj7~^9lC+->gWZQnaOK-w>#e%+S zT)1WkHBMV+u(DmZo5{K@v?V~47D!Dg6otnRFv+t+sfheN`0(VVp zr{3Ga{f0kv1ZjhJ36PPC+;l8ZRbY!!FbFzDD8;Gd?J<4_QcnZAl56jD7Vk84Q;Vi) zDSMo#Z2Wku$#t$sq*8K-!QowR{L7Iv%3ias6H*P4-i4=jwFbK8Wd;ZB&kd%M8R=x) z3t}j<@A8#eRK8Z_>K8<`J<7J%%|F{k2R~YNl&-Rl?4kr0kKGGd*FChy2duuDRV3Aj?Siu1)}fo`k90J#8fM! zC~)T3{hHq4GtCkxL(==#gtkbs@DX-UHkPozd{RjQ)KmztROyyRskd`t1;yuh-G#w&iZ9 zPlFN?BXIrBGokEPSp!YRDz__dOT-2FvG1C)%Lscmxfft*OGmuk37;))1>l~B+v;Ch z1%tqqH#gz#`?rPV+aTr4RMR}a>CI1u`7_<}51Vp1{>3mb|4(Z<^Wkihz4~f5Bt-bm zahZZvi&X0SyE0gdahi*pbGToei@TlJ2wuGM`8&M!WUN}^)q5pA93mIqguN8d>q=wa z{$a#)y8a%xi!A}u_O5H-h8!L3+X`ey=bg#Y?k;d>%;PgB3+D^<-7Q1AckMa|E|z&8 z$txx|bwIrCXy@msH;wsx;P_lsstq6eh#@BXh_tK+udy9H({6T2R%9Ho-b!mxaR&)& zRn_(*eWTTENzj?rY-F(QpdL8-j=OSgv5T>|(usMXvdn;k?4vM*d3}N7cfr9~mROEK zYxC=YJ`Po8=3qeMGC0V3grb)2Ms1EuVweOWRtyy11Npl*}iJHLT3b zPO33xDNea3!PY!0dX@G-_CQwRT|k=@ctLextG2ebizKF4sJmGOzVuKJT_j@`MlP6ac2DWDSICRp^rihEW)C6}{}58Ni-ucvhr z+{);(>6i`ev+qxe9kD}0YI6(rNtt)tlfaEvrtbCj_O#(t&aQC-_1$&X+<~cBZt)+e z7kXLN7G$fw;ZOg_mg`K|CYU-)$KoTlG5m`Qc)?@^feY7OdnHTPS1 zQ%t>g-L0AiFO~skfZje-WBn@IKVa4<51np!(airnA-3#ESSJ{K+t2N3ESTi_?Xlfu z&CK{$`FdZ)j@Z`^=GeIGVn^^Lj7om~t&J2dMQZT3T#0s->I<@O{15`{m5Iwz#R9H$ zAYXiL=G*TeJDJ?f0;yqdN$J!ZbNj4Cc#{Wo#qWuDd&*cD<@AykH*C7#kqIx{hV+# zJ%>;W@MdpqwUdcGO>1qL^lA8{*p`c)4;>wK+VyRgN%th?ZD(_8(ESIUKGjcd=s zv3DdmMZE5(+GOt$7>Sz3NW9!a(+3y&>A=@_Ph3bKyWA+~QY^}Z`Q_!8Z-=^fyv@Px z=YnP>z+n5y81`&Vy}XZgmEb~9CfjSgi7;fOjmP*GLkZJ_HoHldkou%H###-GRpzjG zLBnvVHpgxMo0W)DnoTu`^M%3z!}q(29d1c%Mq05J(t+3E9GTDR5f1bhM;Z#Gov3uk zv97g?Uv!9^vd5k-n{lDC@cdQPy=8!6@@&1<47k3Fe`p8t%KrL-YuPS^<5P>>%AFPx z#f{EP*#7QVZP_MP_R2!WXvtxv#er_Q9rL>G`+IG}ol~Nxr);+)3th=u+-CSPYUx+?A4cF)SBb$MeGl^fsL z&uMPHP%;m0P_%0ucG5(7fO+_nMcaDnm6oh0yUg=dZn$dkg6p98_9JY5EY0uFH=r6d zD-e6-;i{Ta^o67NfVVMby2Z$+1qrT#ck`&-*Apu1chreBDRW1KU03L2pVcQ9%@w#DT?_9?v7 z{POv8)99-7X#N3xPj!`;tPwR``xq{f%)oSG$-GxnQiTCql++*Q^q5tc>7Vc_>j>-g zO6oR_sZH)pDTC<`MQ%aPHU(^<@tS59Y8o%k8|oXtq=L(dt6iroBv_-7z=;%zy*u*b zniI?6ZeNFpa3??Aj6tMdMr3SAeKYy|tRagN3%ku^W+<(NaQ9WQx?IUJ{kVu%s)Y~c zqzauiF3dL+f_HRfUo5wU7lK>24uPTSaD!^2jBKv%Ct z9hZZGf;mPfKF(id`?IZmC+sl!`R!A3*4V~@B^xZ2xAy@=n(Udc`(hAm>bDuH3b}Iv)qN34v2AvkaxSK=*eb8wx$(mC%l;25 zyDWwWm`eX>YCiB`P%j}TkrncI^Pk|Ed}W}L~pI3n_%?*eiJqGO(`r&-ehUf z1U2Pf(q&up+06%aZ2AmW!Ow@7F_(G)=OjPuYk!&YGm9^Jbr{da&eE!yWK8U+3L3V# zR-7Gc--q(|?qn=qnsqmzkz2|wTq;n*l-Qge_VCM{5YP1$A6Bs)=3D4$nec2y+5IlA z+NwE{aSvU^=R58)QBv(M;N8zyrMcLGl~sw^%50BKU&=o7BgfXzVJUgfDdQ+3Srsx> z_xaOuGkm4FDSngfoTvTYD2aE3^iImnYxaKi$o`b>#^~bTsEBUX9>&JGFBY{ZYLBNi znkDF-aG6RjETC&ZV@iF76I>0EerbA7!TBrTY6Tx zZ{B&?EtKhE!O6a_C}DN$x^f|!{iwk&-9pXc6OZb{a%UPbGP*3&c~4V$ZQERR4SxPq zV?#J@B7<$+b!J4aVry^DDKc2dt~hxZx%3fvQFHCY5l*fVVGzwQPBFcxx*pqVS7qg) zd0OdV>ZH#v%io=D`VKVJ?TgzDmi8C(>=2IdG{fpDE||@5G<$ugEzId7>wR5`cCQ@H z`4i(lp7m~^d1T8_V)AU{dHzhp60}|Ul1%6JJzRG<O2`4?^d@ZcKUlPEXpHzRfai zHTG6_xR=v4@acV{PWq{0YZ~*6_$&sj)d}z0Bfq~^{bF=G?VrkhLGi9lT(S-(^aDfk z-S50A`89PGPSv5EJKzmQpJH?`-S#d&b6jUO@lohNY;n^yCQ+}GTrlRe;krMm+wavi zTDZeaHgg(D4G!FpWK1-TqqRTARaf$oA`!f21;tmC!+esfOtrx(eVo>N?uMd|b;Bk0 z(%e@GfzeS~QToF34|#1&Qe$;65?+g^eHZOdCYFmOjxT-H#fFTE#2TG$Ow4XXmdO`U zOL(LB9D4Ms9lnNcU-C~(!m(&iW-w z?60@}OhPPgUy1(dxwi|$&T4(Th}tRP{HI#_{9q0ZmHb=o2f1Ds{Yh*0aT2KaVo2jL zlX>Iip;@YuhC+_<69m4Fb{oi&=Pq*+enNl$BXwC?ky<*DfW1)zw`ObJccXTLt zx0@Vz(Va&A(XK(S#pL%4#=bUnp2vz&qZHjC1D>qR`*=;lITST&vfBjaF`FA%?xJ|^ zp|0FRq0Y{VO-oC%m_6w0_g>ukhGT3ZAwArh>ojm;r@83LU%Ik7RL=I^X7usCnp2*w zN*$pja<49To88Cg>P!pRG&75woY>@>-b>wQ`E{P2ZxrgZGM*T6vt0BQ3c<`Z42KOQ z@=b~HIHDs9lr8OIx+h2MW#7#3cp58)pFF2YOnGi zUx)0k?icS%OF?mrdhLlTDQ&cvo0-@2Nr1I5`q}pmXfo|E+T+TSQIa~v!>?Az$IxgT za@$Fe)sw}{f4rLiD(ZX@HK)I6(N}MbOKMFjB8A`Ku)peD&y}mtry*PdUu)nD z*kr=`WCx7=5#{r9j)-!Kp^OgqbYUS^ldJYjLM6VbcCYQLTNZL6_h~OJzJln&eN%*U zq;(j}HrSS(DcpY0sm2R~tUs%g@0OomXL7bJZS*xmOEeey0NE8BX^wesVVY@um_IST zRxRZv%(1`QN6MEv1jYeE2#)uP=yt6$kU-eV423=k|)aIHzouFS2gMy2f|(9 z7KL2GA7_UKd0K^7l%h4Mn&{fVmqT{PX@SpPWK3IpCE9&xs4bgwQ02S$*kZerQn~9J z<^ym+cl!ss3A?sbZXYTrv#c-kOushVXy1;?^foZ;Z0{G(xf7#gbO9xzAD);iZEZY4 z$EeZT6+=`t{LV^Q@EiVlvoUM@zl` z6ofN7To##Fa^m1y-URsm9m)fC6DjFW2bz+(?1#51mHMVym=^llgqAyu%2=q+&&^fs z3xn$owUCj)xb~QC#;vS;UF(LCw#aC>c6SV_tIs5)U`XAfd+f4VXFEuKJ=mLS{f^ddPE)_<<9U&c{5anAvrmlu@9$y z$vr%7k9C?%3B=y*Ynn)z&Ju669OtoE^dISwsRH0-(Y>9hWg=Xrb0>FvWEyc2Qh&G9 z9j57REx3ItxBd!-|J~9Q6RWJrTpy;iRovmXS#Le>gSM9R#}U*&uHZG4EzZXgmepmc z#tt+OJw~bAn?beBFY7dIOFv?lMrT20^(*lu06L=my)lESp>_@;@2oM&Tp_J?UQccL zF}d*8si*x_a!WnpRkXu7B7=?TsWUOLb@kji%?}3JGjvX#BuF^#4mux=wro?;EmPMM z*x#!@jGdpTv&j~+?|;cvom{W(kzkVEoTw`|KJMi|pqHY#G&CJ5l-mjKE$}w>%becD z=<8oLW$(}M07Ol)AksQnKa6ph7@Y92$v%}mBr|*&(>BVgUuYzL zh2J&P`;x7AhOOjWpIc5XcZS`Np}&p&IJ?EPs}b{6?EGl%hZMu6v-N)nTtPS@9IH`Yhm`+7y zciA^4@?8lN@2N>LvJ^0C%u1i-v2E-gEo7-SMCB*)8uj|OqJSe*dT~V&uk?8tln3k^ z^pM>aOWgKjmpsa|G1FPO(tL1idT)~URjI-Qg6>=Hdsh`S6w~5KKXba%IfI_4n(IGOd2h=~2W*15FQX6T&NCnLo;|iQg=`TZ%+V}I3LFavq6B{) zr8s<;@M*l^Qy&E>_{;-ZuXaS%j6`7Z|?~Y60X;4DyckG4CCg z@pBWpgGzbV-&Q>u2P_35iBJo zSab?F@K8sEf0Uh8GPD#|w;0J%loJ|xE*q^SeHbX!MQ@$Y;6q2<9cziuy`i3@?Ez|K zP-($OWB~NT#MM064c6k=4IDuNZK4`P9~cN$_@}j^PJe=BB~xekld_ha(o!>T+q!F^ z>8Ow9as3Ka0f`Rykg_IOBGKutzptiUM{US7M5!hZt}($P9<9ms)ott97zKMy@d}o~1`^OfJG}ogel_Vg1 zjYdLSAt+A3Hz53`maJyuTz30EY|7S^o2ydFeX}&d zhw%{`Pe)x5&mMZlF`U->N3}{EwItKar-8B^XBu-K==st_;j)Icl;0R%za}h#H}dCv zYzKkdq}qCWuctDcXl)C*3{oxjjKucV$GE{QQL9;vw2>r&17kyiveVA(NR z7W za^}QqhbbvHDeRy^uDe604xg+Xs;IN7lVvh)qC@kc^<1k`?jd-0Rao$aty{q><=c-< z%jo)!@{`6gl<#lkMXUp)cy`9^2NDv#uT+bApEDahoTRponMFExDKfW#JDlmj0ljp0 z-@6z`n?CUL+^M(SVV!Pp5nr(F-l{P8$;n9_jE4uf;29CX4%@QUAl5JxGBRlYh@D_X z8o`8k1EslnH~WZ?$+&3MDjL7}i6pDyKkidm@4oJ6_Hvj_&<*K^&v#!}eD6^&$%He} zW$)CrEAIM+DzC@|bOPOClLBu1HSH|riI#g%!d|t@iM6gOr9}#bLNz6sX^$NOX5mD! zP4b^abej!F^msoeD=?(%NvGSD9U_yz-32*GEk692<9pKINemCSNZlbomtR|HYKxbH zGPW|y{e78uw~M<4?jdnkVWAuwtK>2VDc$AC+OlYU)67o|dfMcSzP?a(Y&vr~Ymd{a zwdEPegf|d<-EoMPb~lvyU{>P$t2x8jRR^68R8{?Xf>VsB$BxG~_%;CpIOvsc19!JO zM(d||W92PJo0j-yMbp%nRg@suUOFCF{TCu2fjICg<;3Z<;5>0fRkPlt0C+A*Orn5Be2^USe!HceJnBCSq`1TWKKMAdWgZ~7|D z4T!`1DQ&Tz&I!t9^%gXewNG>sYI$IRT=Z^xBR1*DcVL2PXeh@{cW$Im+;`~@&#-ho zFqDDT@ALept<-O|W@gW`Ag3Djf6eQ6!vz;-*-@#pIln)91oz)hjYiV`bY6qSJ%Ns( zim+^*ctNA_p$>u*5=?cL-8r(hywEG9hca?)`!lQDY}7Jim82hU-mJgyq_gUCSPP%4 z>-MlET4$=*oIjks(IzO(Cfjnj#k6pb3qr_wt81&q8xGE^eq`m2x_SH+iMj#BShvmQ z58UCK@qcUIH?r-)ma~Lk<-zB$i@AM#JR=K8hgF}V@7~uVtKf32=y+TjGAc@{E^Wv& z(@pGEJDDlQGA%3L;fu4tk~ z_v~pSG0V6{6RaELqV1=~HQP{qiN9g)a(|N&D0LYr|+T4u=ck!>@&EZCOr)$Ad13q4cr2u~?o#l%NvH~W5mc~V*hXH@A| zn07SU)Xa?r7plMY8~G?1yhK}kgE_@Y206ByrPh2rrZuw~R5_I~kLr6-Dy?*9x_>9Q zWgjm^VSjC15@DWuCVHz)(ti4N3YLaLm->o2;|WkRSOo}c19oB*V|hVqHX^~eM7iSR z)-8fOw;CID{oKn$=Upxh%A7uc8mR33J>M4DH6|CE`Khbr=N+_Xd8^O{5O^>S-Mq(=|}rM}theT!yfPRQ_bV=87avt_FA@l|+w$grzWYm52NlTB!IDrKI3 zq&ou#5Cfl_whsY-h&PxXvTPo-8Gx(gm$SQ_2@IJovm261#hB^JsD6-3F^NzT&KX!? z$o3=cA{#7(Mt~r0%IHsJv&1wr^Gw=QSUu|*{qde~_8pov{)m3Q{rq49Noy%1v4Asw0Xw9BqCaOtxodnfd7FDQw#Dob(IiV-!gn1h3>E zfEO~k;0C|f3_J_;DDAS;RwEVphWb-4c0+{8pDb=J5s&bXOh26?vc4(-x{JqO5HA!Ld?tExuP5mqnHc?IP{B-dslNXYQBCYGWOg(g~6vsnuIfu zUauPklT(uKk$3b+On1wgS$NOTQJQPng?h-By;M3Z>yPbFs@@TsbEMUx9Nru1Z6bvg zGt)tiV{(M-dm}ME7&vpbk%0PxwhOnKe7Jck!pphsD;bvUK0p!{`q${xmmVw~ly25p zF!7~Z-RqgtE5mdVt3j(?R`GGZURRG?&c!w?R=3wpW4>S_XUF1&L^czL!e~{&urTpS zGoIWc>m{o~Bg&qCj`_(S(SQu&&f`?WHXn!H*>c(G>}jcV_|)%|(22YoLmnEVTpl4o zg%Tb;pKRwtMac0OhWJI!z1kj$zJ#bQIH$6&@@tB*lm~pex3|jQ{(O_5V1jJIC&Ix$0nMbg5 znYgHqk*>e(t&HCtG*vYI(=HQzn1v6P1^PP-wFT6P)KSu3$y=`4t-uN)UcmpQf);2v zmpt=6vE#lrQ~P?&+>pxSs-u2$-z39iSgai~rlmN8X_@3>9Iq?EWN6at*9g-$8qC476S8F)r11|>+MCZN_-i0;c|TUs}TW@V2A#pUI4Nn<%3`|3{8gD zNSR!5yxnDG`#8AAM+D2DA_RP|)s6ZF$W|+I9R2=kvzq`58DzpY#ol4!<;6XZ5ReTp zjJ#Waseoo6(F`8LhdBK*H6??u2ny#z*N6s(QDhW*=-Vk|k6za&2i~?Ss`fnT-sFwR zA~{L03|f388=N=~>Q;R3OV2IKfSaeCdKJ0wD!3n3_^`~}cEJ5dCUFl2sd_H0{^SNYi3CAoY<&lVw@H+skM{%M zo*0vdBXDYoe+K1owms`HWSOGlY~UF%Z|+0rE(*1O$~-|020gh`-Z7P>@6SrLT3&~& zaYG)3KL^ppOT(m(11mAm`&xaz7GbKp01V;0mG`xUVOX9}l zPeK;nY-$FPrxokyzgjy<(p{oVfFAYxf=nJc>^Qq>QvuYlQzVIerC8X>j&k0ygE2OoJ|>ZVvmDuTt~6b~XKD%Npb(V<@gWsdAAm_9zGj5Qs) zsbriblB&**qrWG`adZ{U|IaMY8bE^mA0|3%lel;4)r2u7omnIM5XPNrG@g(YQQm~? z@2QZc!SuQ}!oK{ErGUW5jOxY-{pg@Iq~m-)IT-^8xrgLo6+H^fHd?3Cy1(FeOZ;< z1ycP0FMcO#JqhP=)M&whn~M(1NeAjA&q1d0SFOUe$_#gbcGH(%!Avmiups)e(j@JEg=E3Cosh!|czm@a0=_% zuzCxm1*!XId+xY*-KU@r0r_!#;ZGpk19qI$! zPWKhMaI7Up4o6=CtPLD^VucsWE{eNl>E*mbF3N48ly*rU*LCtLux3pgDMn$O}Mg zHSg*thV#$a*p`_}h<5(VOUU$KuwNA9Yd5vVu!+ivJ_N9jt0RV7wHjcbaUhqFhA*kT z{N)iR>6)=+#2G*OjmeB*gZxKVUet8`thg6J&)AnP4uDo>7< zu+N|@j@tq|NB(W)!T%S)@cWGW0YZLwcM>G2%EP38Q9 zkiq)Ea4yM#u(U_?=GdeW6SEYgW%SKLwHD2%NijX(3Cvw~eW-{7&b{pr9jI#n$^r&p zb^lOeJNgJ=Lo@R7b;W+t?LagQQSZff!685xY#A0Ok)`{qRiex``X6@tkv zixq4NZ+otn4OCw1L^wJZ60=RdXtvDiXSRF}bgrkBa%E7IA z3F^-#F~|E+!cLLs2=0H76YxCa{*+J)AdRiVkZ<|C<`cK(_HR&2s1|B1o*-S-%9IT0v_feo}C+>53DhEb#w;9tZ?- zF4W=rU^+YFNQ2)0as{_T#6vU#=S|1dJEw%AB($PXXQoruAS7;Y?z>H0*Yi_s@>>9R zbQ(DDj^@&yIKkDRvql`t3&$YBVOcFt_SY+D86L)Z57${;JF-d((AQ~{&#timB7@~2 z?9yrc{5Hm-x#zN!l5yiQnDCJK5+%3%!+(=Jypi5Wp)sp#4woyWVqeoS>IrwuU*{2*xF;mj3=gqCNA;R;B+8^Kow5TY}faBJ9!;g^& z#pdwxs6H_vVv%)0${as<9K-^VBxlq9Pcf zF}b8B0-*=u8hO+*#^X={_vjV(LzK`?&42+ahV~EkTOY@ulyT64D}V_P0rHC|U)Jr(ues1>LN-Xkd|0xD1M-XEZjj$itbUiYmpYhR|Ki^UUx>^_zoG%9u#sHY9z zOcN;gv{QX&mp?^Tyr#sk+T53^2AUHW1n^3vNI$NY(hze~v(#va1diF#93PO<0U0bw z_@)C@45pa5^XCdBo_R+eh^9R)z6gDWih|pd6dA!Y{=ZRo0;9I_sc~Grh*mUgXq7Jm zf>>eKMNoHRyv?Wsae)lhN`FESlGBRV|4#ym&{BYffQ(h){0b1`zYszfr@Z>>1YHA4 z2c?ilMfaU3jG42{yLku#al}t@ZlV{#%l^N6Q%(loRH&YulNcu=KYqb=VkGEMZ-D48 zHI>rc-4{vQ9Rn(Fkt>No|52@tTepA%aYq1n7<5=z#FCe6RBBjFb|A=`oKu3(V<5yw zS3Sq~58lh>nXeCXIDTuD0@bOL)pLAIWyr)Vtjt*N!C<-;F8|9LynLX<(m5^i)Czx; zogM8qUb12r*dX(50ZJfux)ZBQ`|HC-Ahc3OqWQ{7Y+%-FnnYu8nzTwV6;K(IIYibw+mryFkMUGUNhGn|=0@4fFoC6QrwyiZc&vXkFmTI0BY8NZLYhsZlo(sP$|gPXpUd zDZr!P)`7cwv{$iq)xL-%5>QGVngUnQcuY)?lDGm`{{mANK|0B3B>A|jQNdrs?c%>W z3_yrFz~OqUJG)mzedeQM|7ymp0&eiH=b(Dd*AcVlOtqfTQ5&me{f_y!SvQ8 zBfl7u&fL1wRa_-7ZrRJN1Sb%2h?_vQui2z%rKh;Qin@Qvb*?00$YM8IfjVvd{0RRb zynVv~m+JZN?;d3Ak(tLP=6>fSvK38w`sDvC&z1l{`A4Q|h2ebh@Dyr~rKkI)03?i# zC?q`pR`iurt_?Q;?eMOcp}X4bUoCzKO-zOiwAOokXb-9J$Rd6aoVN*dEcq|%Kkrhp zCe#9m1O1gCX1|2_1^dGkO&gYzbbHG8L)#!KEXlsQ7B^Vde^~3*-}Dm#AhaK6OU3xU zD{BJ7<9wUe2!q56%Ts{uC^U3^(P?=%72KLQ3{T{0#`?iaW~;;Hx|8lhc?bFx_1=v- zvq1HromMcQ&jUBk@#!25XPjNR-}~(85(i#Q2c@(S)tyA?KtEU_@^=20H_QA0?FRaG z1)VUO_ZD*FnZw%`xGSo-pF&GSaIz84`rwx&CbC53s5lMwqaLK*n-@QC!o;jkm6o~? z6Y_ISY8vk7_INtMFt~h(UnaVfq$ZI zgQnTkFXBlkJdxxKJJ%_IMnJ%HUzC>^2qTWPul+VL%ExOd(dhqf>rkY$94~?v29SD- zl_$~9M)0t~SIOucd|fPB|Fzy4#u!+=LQ{o0vSL*z8nVyq{y+M z=`-^n8Dc5d6Cp9BTjV|~V7VMfl3IPwXF$E5xi?Pvr)XyWsahV6JP&w_f=R#5z2<)` z|2PjxD1?4Qqs+2Co5Ar-`u0hm7SiK0B=r4P+d+(qkOt}+SG>AGib$piUWqi8^^V|^ zK$FuMY|HAl&~jlkJjU_&U!@7GGD6K!=>wcst*o+1e#i@q=>4}RVzAS6-AZ|!F%%`b zwtX7}HRS3+>smLQ4by`PCkbk=ypjn~ zg7hIR&yCp&JOwQ4_z)7!FDH9t&jYJ-?nkz)uOoQm!)GV-=RsIxq#;1)hXK3~@sVt_p4MxZnZhcEUhr$-% zczL6(MsNcJDazE$8%Q2S|Gwmi@F3r|{zT%kIfCZ_f!_WE1jBG&l=~TB`;dj3ksl^W zQgJrFqU1ZnGXPs7#L32hUk}(1yZ#s1Ll5qtq2~qjm0ygck{N5n#Xi!2*7EPwoj=4$ zil~>QL@W~s9lU`Qg^h#af!`ymU3;TxyUIjQp7z8y$?DkJHy)|NX&E?fIgzxmgEv~Et)TqBm)Awid6_I%Dgw~ijg0&F#NU=Na00~ax zSc^YaSOK*mYQL(ukr&K4*$KSUU;fTtH6C!ZA37S#gYB#*EoYy&A01%PuwY?{EbbDT-Frs$Yf zE%8IkKaiGTK(9>#eCKw`m-E1AINAPs-V|R>a|bDEYLY8g|6zG1*2}iBz(;4@Ga^}N zfN^A_9$0`p;0bjybFE4V$RUIVzSZ2tO$rrpYJDep3XT;dRU`brAYkR?19GutRayl@ z0S91;8!Rh-2Ew9dI+{LK){;wlHM-`w=4rqKnBrd#gg}^LI`V{iWV>-N?ua}d&;dzD zQygp8D&tB>Kz$78-s@&Mi)$eG=gBX)AKg9AcUAVvkoW|SR1W&MyX&Ye!94+m$kRao zpm<=ndv5LaF7A1~ZvJyL3c%SHW7ns6>Tp*WtY+`(b{%9E?T+9}oA+2qvMzqptx|1U z$~m1o?B(ajl-RwM!!J%y?Dd$2t4R4-MpAKu9_nqf%2l!;#En3iKLM=iGAswz&imz{ zYMU^w1u;IAk*(lJd2rLFDqIeQ&VvQ6&qKVqDgyCEvy>+G_ImyaV{K)C5cBE@2r+XW zNvN;0Iw>kHO9k-@M-Moc0?RtE=L(S_GSq1#9LIfcZRG3>zR7bU-LAVV$!>bcy^a$; z^mZbU@vO!7;ZmpVDK4FdL*n23O4)nAbJp55qvMBVoX201>6Pe3o*HbACPg2KG(*?h z51sbADl$65kUKv=&*(Cg#EgEy{BF>)nmCCEUjOV)(xLo6s9&2#khz2Q4S}OTl&XHFI zL105ma=TK6apOSq@SdHMzmo+0G$+y`#e6{utb889|8LZUKwT28%ut4Qo}Ff-Gx1=z z(L_-$|F8<_+(D%yHBa{%S(|$)4f4!jN*(~6)UXp`pvpvk_@8=ZN7X@?PLpDYycZgx zQ+yAMaIi6--4{(8{UAvq+X|(4){$vhQd@`>+!D7gK^ZhtBNrEb1plPiXD;*DlT89<2{%C5H#0e>99q+w2qLFBu3uHRX! z@B;urane9<b+J+On`y@{GL(#T;aF~6lg zR}(pG=QQ?gcGuvrjCl6aGm0Q#^`j>QPi4g&n|^rb{wcy#q%vFk+CYX7ov&~N2W%9o z{P5~qj?=N4vl3o2Wp|sB?WQ03MTcQ>Bm9;+CQS?+MBBccyk#fi&38i4AePN!3BGcf z4NgvZb5tMKZ8qouXW9mfe46{ndZ{a?@f~ywU~zFVQPLUQoROV5JNN)K`z)7dAU4F; zD}&c+EL-_gGKWcwq`KwpChXGV!kpvTlP|EkmGeEL+Ewa&Z3#WC;c}^0@$>38RS>Hh zpI~qQfns?*93bdg-VG1Y>31*frE;@@IZ>o(e;Zj> z?e9|{;iNjJ%PE90H<+8@#0rZ>m1KIH8-F>MWR(*4mQ#-8NE`8tU;Dlvc3YBcc}f1T z>I~pc?7q0dsRsVIG2H+UEB-&naq_ntThJBffe#*Yy8@kfLjL~F7>s#1>s&?AEq2(~ zwB@9Jb%ASX8inUP}Wms@rg>5?6kms z8gJj&YtWIt8}V&*+TwrPCh<()@*zpWucP4oeTHsI)7zrr=?lOsHbE)z6K&Q=mQ%2i zHakyHt3#o0Zv^X7c!*K9oXDb3(c`Y4-ow?JGk1V8yh9-^G6u9`Z%r(f)Ob|g#-?#C zpXY0Q2-EIvoW$A(HwFq8B{H%CtCUvDeC3f=Q(B$0eivbw@zH7-h7{2J8kj#kLRx zMZmiMD??>F;D^Sy4Y#FRv^5s($enV1aY@M6d@P$WS^o^%iCat}rl*GW0H{v7?I<%a zBdzVCscD|CUuQoW0@!BXY&|F!D~IhaZ+o3HJ2*UGpwd|Q;T_g9reUeo-5VTz`Uq}O z`xH{TTp&G~={UKODDXRK+I`f5Ov8?K zvC+@vA^j$4`crzt{;~EJhj0hsz)AALwr3m;L!L(1qWUHH;%q;!95_qbJUvTatsy1C0eg--O*%-7LWvynX zj%@YDcrFdZV&00)KaCCMiditunDg=vHi*p>yYf)?lzm6FzZ`0@AT`Ww;j4Ivuc!@6 zn%NYv2)D~}$OVo>N%e)@dOcuHN2ue9?3FFj$JU(sQGV*Ex{kpiK7-)3)ELvyV~TSk zuSBPM8WZcp>(x=(a}(mIPqSS2F(d9aA@w;!V-tQm9tzn1lEJz=By={_l_lz`9K!bRU0YM)^u)jt}Sb$+_2w4WyxLux-!O_aK!5I;%~K0E3r%c%Ll7&)RJ7Rt}M5b$TT+J7z$ zcY^O;u2B%XFuVROX~+8Or;L*NSA)lT-~-A;2gwDWT(11S)5)myEx5AMz6#PR_}p9P zs&etk(O!Ky2gC6hmui-eDaMU0-y0)OPC!X@4+e#bt$8|HsWkE$kf)Ww3nYG8g)R1<3M5L z@q$)UkgT6Hy?~7E|5I_|Jpe5EW|guN+KXRhc(YIsPfXtm5_?B7XAV408P(nw8yIZf zGQ5|!LS$DkktGs|$NxYQv=jc-LGuKh669-fr(Rt-Fkn0Z!`c4|kdwH8ye43R(XADf zI0gk|!X(zOxD!^Qe1Bc*62*fMS-r}qj8MiKV+N4DwqWaZs1{Xbuxjv}492|k1o^rI z1n>g^iFE4GF9-VpqA-cS^BLgirr5#g>J1};kNXzcde6xu&^3^Scc)6y()4S9i4mr49=dT-y+dd*XPHG=TxkBZSixDiA4? z0!`H9pmSU@$r4Fh&J5kwhtr7Q_iRK)<9}EqZA34S*#`$2ik=+DeJgs6`|x_P45)c@ zgKt^2)jVh36YjbrIIA9v0cj)(;%P+8QXD2wKv;gYf+Hva)dy3qQNmhxpCFE6;+2@l zX8b0=BbIk-DG7>sHfpXJ=Rr47eu#PZjd$sTD%!vxNUp;CM)!JENgU|n>fmnKigjAp z7CRrCeU(x)B`YQjm;Zlaab!WE6Ji<)amwp{kO06=WlzU8(%-?xS0(1>K{QH}?E3{g zJHP@a@lTof>$psy$r=3y6M__=q-$Y!b>L%(QTcu_o;UCne?1t);B^z}W$8ZwfX8So z(@7EGy5lh91tt{p3^$?`52Z7(R_HLse8IgnQY@xPBdtk5R(!rJ7^C=MxF4gP_zZYa1(Jes7+gaS%Qm z8;w48b=_9ZW1b`^U~?9$b?3hM(r0a)YF6lwef@hd=L^_YFI-&q>aGA@p$X_pFuE1R zB+8mC2|#s4v+9Crm@l%(YfpNlSa#@qE518-V5jM7ASGoSp2;E*5GDFa^ttXbpvf zw9VM=GZT*=VdMB>%ql}7#l+(W!2{@;xeqV@5mL<%3bl1(SDZW%Hjp(H)`o-j>(xB< zf6_}jx^Nf3=OV-80Fuvh@2S`jY>9TuMGtwHwv(Wy3K)EBxE2g_oHk&Wze*NKn0(q8 z0>GtW)Fa1}TfdOz@t~y4zLBD|LL*P=nn94}A?#qtzZHT34?AdXLT$GqBincwnAs5S zd33~DR&ai$MYGB+qIJmCHoz5t?69mMCht#IpwVDKyFQQ>qP1Whgsl>+YmH}sWPhGI ziN5`*>`{86T$g{?aHbf<7~m|unIO&bXJ7LL5Y99Z?B>6EL8FAv<)z7-H`ZeW2e3JXJKj$Z$0a#(k(dy~ zj8$%Hc?Zi(juQh^*A=!#t$ZQkIGfNKomlO*KY-)6N}kcYs0vVyD6%_^i?u0}4hG=0 zf}jedB?JQ6;9#~+4g}6~jo!kN(0PGsud3DYKDZ)#gZX3#uRcJxMSmXL1D%82N`By< ztnPAJ4z$(!+_6M3os+6_BL1sTdoo4HO5Nixm*pb5AO;&3J2O9sLd`zPYMRl4HqQ?d z4ZzV2;y5BdAUU8Xf&VpGoX(>rDiZVqTZZGq6%pY6wIVaLG#KWM{GCFB8*)(C^LqWW zi^R1mn9!5ldHq+;C(>(ZHTV2DSAXW@Y1;}h5fa+~^cJ{K>C1THucs^6Dnuc)NHq5S*%(2NL*-=9G^vkTlgw>!4FMtq7@`7l{wY4& zh&ar~wV`eT9`-~7vfGf}MzMJb;)VCWrq^?Z9$A?wCvy2IecaLij~?M!(8k4K-^qJq z8c}=xN7DVG(A{qlE=js}B!fxcle3 zENJ|V&i>oxq^|DgdPhf%ZD={iMF_EK_b*YC;2;;!rku9#z_rv!y_hUf^5 z^DCalg678QiIH!!a*Kd`AvC3&Kt-r^V zNFYdsvCs7=riff7^_Y@gG> z!ANtf)<=@$t3>l-6}&OO2fTcW2!!dgc-;tz+RGu@lk+H&^!tFVLd(&`!Fyw)L(vUB)t3pOLx}_?ps-;WC4(&z2X7n)-~RrWAsg~b(&{H z9&m3Y>Pql;lkb#OdC_GLim4Y(Fg zo62*^uKgO{U7@vlwU+hrA_!ay=X_vR$E8@Udq?10DGZ`+Pc()DU(=e zaiFT-EcD!7P)O{r3`|K@S!GJ_CsIGHx++0t)y+h7E%*p*PI?ksUhi9yO2GUSa!ul+ z@hBge3s_J_Tpum_iiw3f^nKO;bwnUKmT+cHDUIRj)8}92&(l!96(CM&r1XCao`<{( zFu)#_ai#aPPx-$K_1pR1EjN0QGC8051E%fQA9vC23y{|RdeMID$DywxNRet*dpbfX zNaOo!V%I*Lgjl)Ro~#9m;)q$+KJJQn9h*dL{_4pLeN7Zf(4G_XtV(G>_qaXgjThw; zAY1+#MR}5oxLlI`NppZ~YotrUrcjQKj&IDH4%sU|pDA?^%=qzwtaZQ3VxgB#1@!LV zc-W)Ju5PW1xBSMw%tkz7Z{1+p`adWd@mOpzGV(;^rB$~latycr+%>S8c|XJTn^{GQ z+K4>r{N-XyOs-)C#tlsk`>?a`0h<81^(sRxl!zIn&8fevN(GRf7TGTxD}r~Zo-CO| z&HczFX#sQ1f1q?>&!VWQ1Ye2Nz-8E|18tKPGoE+ ze*vyb48C^|R zkN}9g4G(aP_F;h@lc5d})pa%h8R$sM#kE37#RoEA^u*W~_i;KPejc&SFAk$aeQsqI zmK1hP;kL$W&$6NZ$C(+#CRFYSQ0~9`wDeRZfJZ3=Zh^?7gd@iFm?9y=^rB=JAjtfO zW%Ry3;%{6_JlDXKB=yDQc|}Db3j8xdfM(bmP`Z;5O~T44FLJpkBo@=hS@jZK%9Hzf z*FW&LPbnnmp46g*j3?-p!B5KAweT?apszj+@+n{Yxu_-F=WovL<5(g0*4p*7FL!eX z^xo6PYi5rYEhCelZ#E0kT-yj)c^&LV@M9;Mx(B#)%os2P`c<1U6ib)=QzQcjxb z3=J@o;iQ%m(05+Ce}AW&Fy;UNO#Suo&yEsX*t8XBG2e69IDF#hy4mEEg60gBjfTS0K;%^oTXu~zN|Bttth;f0r zti9*9puJqh=c%r)qA#F{!7q>{q?xNqCQIQw28r#lGu|E-8y3k>ZV6*Z!IY|P@$Ble=IJEq;E^SLvRMN2E4861sCa9nlJN{ zI2O!gQO6dpf~*9$5lpO`EBvMVUSJj2jE}1>Vo}T;RQhVq&7Fccwtn4KwID!I^X*aH zsC^4JnE$Y3nWSfi)x>>(^@=fGPwk_HAv@{9aY_`)9QSJ%y;Yb+ALPiw#JJG4Tc9

    U)^k4>L7z9=M!vHq+?jpf|{i_aEPkeen*~@{m-YT^6X+E>R8+IePY~qtY|k zP;LL4${<-OlE>%Pmc}6wDXn@T9?3e2k?Na>v%uV>T6)6czXAwei)f#V;k`a<gEJ z?W@2x} zd};%7Ud_nJE*jE&MqDfXz?DCJLduYop1+$jGAXb>yF+VWOn%9E-G%o#QqKz7&v6=5 zYc#A6RIt{4dHt{|UTIFz_=V!XRg(|O z-uHM{cLY`M`t!m*yfFT3+~b|Q&)nMaoe6hbN^#E{J?S&2uDhnkKcp|{)4dcOeq6>j zRO^y`c1oAeg*vhJ+1Kfptoz@LzYG#<7jHjWbTe=~?Z~H}!bOYRi=rcy@O*Ks;*9hf zb5D%*kF4i2WG(=I5S0{f`o=I)LHod4RbnJryB{OH;Fab>^C-=ig%g=1IbA$0y(YFt z=k4=+#qe}# z5WOf5ch>BGr~7W(Qzx^uPPf23We$$>HPF2(b`0-X>{^qZjRmj3k(UtwCu}Hb8%6Klz0Q_QEKXex8oN?TOaIp7*D&q z^MZoJa8Flz?D}Q)89a&7rwloW0bYc6=YdVfZA784-7Y=M9D<5Lsd0n1g~)F}l)mq> zd^pdt>qsA_f%{LO`Hn?OBt_D<+oVgDN5y)2?yG%E{e3yZJAXe)_k$yexcg5omy`>| zJKqhWM(=xYNYUU>p?PYv`e`bv4C{|$uV-m{qI2=&gJa!l10F@0T;F2aXgQK z8mNOz?U#>tLatlG>{f=K%MS!E`5sA05Fs>)G0-+f3#aCfp2qx{>pWTNho;b=Dq5OL2s{rW9PFsZA}&NriG5(L7-I?yzO8bPuT{R4|yU! z{DQPeAIjOPT(;3SnyyJd9dm)}maUfTcIVJT>D!Hbe%gjRiP$XZ>s^N?tXYT*iw`t! zIB`coK}GWIW6BPUm&M5KF_I5GH6p4~Hw<0~&U--ujNdm|t6oN$8iuZ9cwvm)YjJwy@z1 z!-#RjWINTCL4Ql5`lm<^!v6-+69F-iPvsg~1YBVSV*;;S&@6_F)9-?g86l)Uz& zMPK#1P#JNh@&I_aPsIHL@nW}6tIYd_j`qXAU-)48d1-5GC( z%f#D8XFaCseIK`y49L#{zbCmnj3qs7qv_Xi$&q(n3dLK`@1CMI{g(7P68MoBUMh6zQB{4q=-Pk38$M?Mw5qWO%feZ8CcdLir6BVf{1Chme-KNnm<7Rc2 zj?9*hzR8`esYSiWCi>!rvA+*-wmYFokE&U?S8PzVEJSL~7AoFUY*vtNt)uuK-; zzauM{no+YGON#$?F|qKRHurq!BtArN7piS}O18w|RA_*H1iY7sks?LGk&#bVJw~-0 zZ`pjyp~Ai2BD8yTArch_xgyQL^g1`$#;=35rTn6G#TFS8Oz8Z)?G)<;%m4CdnW8P& zBRsW;G%RIpvBcl4D7F6TOzd!dO-ybJdSjR5@mdDIWFUudE&?k%mW3#y=!kpzat)fB z+n41v1fHd{DF0l_$pF4u^#4^{U~3z0jaQX3^t{ozRCRWCOX6}sVl2=|^?UU&*@-LU zAk>QYSDjnBU_u+K(6anSdNG2ED1ro7f9_WZIy*PnFx!+c6wsfQ;f}7Wi6+!X8HzQ*VQ zo0RAih&9#L&y0LhQzIDH7S9)pFgs|(m-n4eZa-rEm|?^u8v}xlSnu26^uvQMwUj@S zwF*B~Ui7u^jZS$yUh~_2LlwEv`wX7hv&)~_msRNC(xsDSi=$#sK+YM8-BFn`}JAHgI*ST9Gz#(DwL%gR^Mb?ArF^JWsALY2ZMK*O2mG) z{`A2~&a+x6W;)`wLRbXHeoE$UT!ir?yx~tu{Zs&B))y|P;Y9F*i{d5DJTKyU_v1-&KTNVC@@*tGk+eB~= z#sNmYk`nbIMj?#0Nl)+kQ~{5i@yVc&a?!|zEVwg~Sm(wjQ8ervaWd)rp;9Z}xWQ7&v@rKOVN#UFV1PNV0v-Ag)mR#I+2 z@g^!q3U^I3bP8EY)6zEHIdS|=euTj$$9=79ZA8)Ic~;&0&e6e1Z#`M6wD@NK?0q&_ z?e*dL;I8%Zt zc<3ECH_{VyU}B`l-#qJJoz(~LnClt0{~rGI*g#oIKH}!+u*NsFu?zq@u*Su;M`DQoV zl6DtNd zdaec65Db4@e))&6W_I0lHoQ}Hj`S_$o8&+18bu!#y*yHn;K4N8*=F4rXS^lHurd3> z#%D|LB^@2D@?9@Rc;lRUPjC0$es*}y=ntFG{)GKm4VRAnCmFUu#V~VGz5QHF<^o7z zhU_iKYG??qO^fMGWp}5y8#A+)`@({z<3%GlCz|4OE!T4h#+f$+`Up-i7v}e& zXO303(qm(978*YpdDh=BsBFGdoV`&@?jm=HARiFn&EffyjEcQ*Ha{;W=4Z*^=0q zLftd5X*~%O^v>&55xw@mTB}l9UgiC2V$Q^Brx7ty+dtQ+ODnI{8iESamD8tGlu$JDfKdVb>c@}yE* zKI_f!8^5};fqJ(Co1REv7`LUT!6~Dj>4p3T_N)e_zgniTmK(YD+MgE~jqX6IfiL+E zsj&{X#UF$CQehQ&yC_f3aEdJ}ez$Zy$r4|39)*6O+*W37a8%%2IEvY>s@E}SXBybLZ?^sDZX4@15V9|AB&Ia#>|GM>7{gtZR)SURH;}bBA9t(tm##K2^hS%JYrbi>$ob@b zl_H^UOfJf76>^Gnr(#|J(7eqKuW822>U@>igdtZ8Py1PPro}uaq{TaScOg4Key%*$8Jp2 z+K8wT3cB_I z>Ab^U$H7?1psW$KKn3UI?sK6pE*QjOj{*fEBA>9RMjc#@%g=@`4RysAF!Feh7-sgZ z2q_fCeXt(j@!#L@|Fl_*i5^a-Z8&VxIr8J8@D1BB_s$UmwSWr^Rn}qwb&@;$SWq9h zF<({>XGF;$-gP60rv723&D#NM|rKq=gw<$}yrW9LM3rxON z(AHLsy#>)`tI+ZjiHlUgLev6M9z$!u?IP+bSoqM7;lYCkhd$WakN23NUuZq3FU%hR zfM22=dPp#NN?6o6-i#*WCap7top7@#F7NxGgy40cch;e=3KeM|Jos{V;7S}m1?{`r z%Fj3L&|FZcpI3fs$zPA~$|>~d3m<^NU3&my_&-RQT@fd_ZeX32n_JQwYFXK~Jy!O~ z=p?2z|AJmo);pglnscQwp43(=Q{p%?;8xRpOUzzVgY)G5qrY?V;$=tGqe(NZJwk;) z8GA$StVPnmC>PwKqLH+ia$`tFAw+=fakM*{zKxNnV76H^URDL?MPFqm>D;+3rXr=K z-k_-gC`q$*p^uViZz5?fohzDhu?HBMpk~d+I{?9B>fhn5q9P1xVsH7?si&aGuK6|( z=HhSVwoV(|7(uuA2BW#q>9h!&OHG7y9V$VyWO^X8mVTRydj3@WygwjgCh!iKz>SXb zhR+n!_S9)k)77B&-RK-hLq9j3mvWS!OFKo}Cj#x|vSP%6g)n9Oc~bkCbTvITqPx!V zmPlYtp@Cs&0NDkReZ@sBLeHYM60c0oT2yxf%e(pTIHM<_UUQHaD|zzRq1a++l6K-+ zQP)@a9{I;7;rpV03Vs%v+9Oobq5r4nD`8;+j?;BB!`o_XEnwrd$0UOy$blBA z^0sH;wpQWoom%mtI!<-XNA_U-e9@`|tscSxze31kU|;ylJRDu>q=Zg^5}h1LV6<*U z?eWy|*Sq$osRs9(f@n#LUEuL-sx{vg{)t)Ah6?MUU>O%yM+AE6o*?P#Bj1k=Ap9%; zgZORP@_P#BOwK+YwPYBr>($>U@M>zKRKhEVA1y&=l+wG8Q~3SJt;MS@mi;dX)w}F3 zyVYir2nT!P;h@kR+wR++E;)0RP4DlvuskK(xhph?GgryFW*Zpz;`7E{I+NX0D!;FL z5FhGB5{M&`)oPAF|NW?wy?cV~xWk%Q>tMiOT9I-^3g<68LECpdGmwu>qu4CzAZn|{ zh98m+Q+lP7+_3S8!?WL#+MPUeN0cSnMUA8ONyw`wy*;dgDP{0pS&T5l(_`dl_sgv! z-SDkou{YGNMo%hnp5#$XMAg3$eryCxOJPaxtLDJc$0~Iql}qX$ne>LMpu}AWN-SYB zsH$W5p3y(4-;$PARTx$!bJjC7gsGmb%`3bGuOM@k*c$$X?)AcM&wp6tC+3$7x{-qV zz(MG?ajRiAq=jpy=QYqa$UHl5$@AHPG1lT`OSU>|tWIb|;LYSZy-Koe5d8rr4Ar+B zCoS9oa}TuM?Pu#YyKxkAlJS$Yv_p8sw(V13L-_tz=L>!^n)(+AAVwrDLw#=f2)!j~ z2+69^P`R`Fw*5u2Y)L3VC#6@xl4kCxP7rU;8z`>IoMosjV-*zq!?*u}!q)LiPIs3b zNDYp5@mT7_7Olb!@qGX41~-lGE7qnlKX}g%bM(_CG8do>NX;&Kk;;EaSGA%yCM#LT zCE!!l4)%(oF*JKpZ#ImE+3v`+a;Q#1U`u>m#Pd;=+GDX!Y#R_*SVhKvWOb63q8>TE z;98gHry>>D6n(}im5vL+2V8AOnupe)N7#dWqCHLJ>&%#Ix|5@ANRLUlLC4NdzTyfB zN{|DYs)n=$zGOiZJ*-7QlSXj6u4I;QKxPP@bs)sib?i#En@r_OlA}-u(vB%I!}{Da z&Xjf>VB%|(;kmAnNRN&(M9R=w_Ubf628QIs4 zd>2~w<-@y|I!LnBLF)xCRPohb`Wf*C(14HOmf`nMeSb{isBZ{MlHgm9OT+G6i=${6 zBV3eL-_=NFAb&gq?e8iF`&}Anew>i$)@4W6vte)WadRJ+c1loz?AY>{20DyH1NDQ5 zd@>bFJLl!IEwT~Oj6r4mKwV6KrcJ!gEY}J92}=r=-A#q?ngUWGrc4fTMEMi0Bk_f= zWA-JBA)dZIx25xHyOi(OvZCaGTI9dlbnN0L(QY<~x<|%KH-> zjfqYw=`(pEJ<)L7P7#Dw7GAL+Y=_g?E(30w`a^a)lcPJbeINmfHk?*-jGY!vWGI?8yU=RJInBM%AFq)2GvO=yqmWWyDEF9?+r^j0p;S;l2#Loct5>dPU0` zQ2`ec^nCK|3uy;85&w&`mWdO-{Wcmjf_oB>P9xkbeaF%qe!eU`@9Iga@ep|14HFF7 zK}~AoJSpF;PTIBC_MdT{+Tko*2k|54zW#gs`f%?9=%QE$+qxg+9-|X$`GD#DV~H^k z8X4FDelJ-oS$l9!aEno`i_Rza#@UX?`1mH^@(m0*UpPhGJj&4G>D7NOAQ`;gvBMVg zCqlWmD6SmfNH*`=4ui!C^FaVeg} z;aG&0AG4m%Y;Z|Xh|^-Y^~UBXZT zORxMNtL6oSsP+$=(e@eUj(vzX9&GzC6xpaFdeiz{jGgnBdkayP^x?lzjeN-pIS4_~ z*G7-|pls`svs+SwByXC;7k1?EVxgRIF)lz%_?d%S@)}vP5eJ`>Uf!Zh-jR73?m{9( z<~bD&Dck{9wn!o?S!LtcpA@gc`zMBX$!4ww5#`}JNoD2Z#dHYbU2z4j5(rSC!xq0XrCP9#J^e7LUv=VSm^dVXaTip+#s(kTQMLMgI~ zQ2vg+C0OpGy1%$^zhMjAL{YJRWMl6n+Jd%-7rQH*iG3&4p8pM)R$O(19=4Y~+fMF( zxT&etdw9!zh^D7>pi2lLhr9R1UXbx;9dN%`MiBzzxG*bIm@t1Ida6cppI3BI1O1ja zd$NB*nwO!8HTss?$Thx3);F*BNMnR(9#9YlL`C?O&dqFQ|GoyDt^b-6#m539gicDI zdHs(3!rM=zAm38bfoQGPUmx&UB(l->(`$SOIa**T-T7B{n>jrXC7m*}%t1 z03Z4Ok6A;_NHOX4*V2^EhTuEN|2Dgvf$`=hglC2B5gjoP zkGn--P;)NnUgkq^=XYTqtzdip2$8rY`{uGSKYu?UkV{71KQ6qwFx--3f8(dc?-PU1vaA!A1k|M#pT>-t z9WHXbr_YefwerM|)7+8Y$Iflg0_{^|1Z4!_>Z zI67ePQs{Kkd15^YAhj_2K=H&^!&+wPmp|Q%*b>=jz4`COaShVW)FTV4olj^e7udZE zQ2EZ!fl53$yjiWTSde|-k;DV8YGv3i_pqSg!kV8fru!}{U{vUy9w3FH`HL)dw`3jW z<$cJ|c*)8wTqd{(VqdzHHAhY@-d-&ox@U8k&zTBtY|;^T5YV8DeRz(tOyI~|yu90< z;^`}(Guk}!l|+OxD${k~^FAks@iz(bXK7=7@@U{yXZtXct1BzA@QX3xT0CYV(Qhyt=RQTSAjco9_e`*cy3%6REh^Z8;w016~g#{JdIGSc_8dF!3rRV_r}sIBA_a zfAD=_$@4w~vj?(SkrQ>oe0;?`uH%G}e=f4*90O?v1ip}KfVY`@Wz41>vd*2}Z?xxu z-({f7Udi7r(u)`)sRChuo_$H`0Cz@Y5nk0&Pv!zBsxz3!z+Ia@Vl#c-z30YoR(I+3 z6WC2z13w?(+N4f$unFvx^!wT{0XND_;?kI3X(fk^r2mJShDHX_=81&yLP z@AZO0_|w^gUIj>EbFM?o)2T38gmFtqhXwnE55&w9=gGoEs6FFY6q6|JI?$JH>RfxY z?Y3?FX0%cZl-k)Y30CA@$NFsy(h|2___9zOYTc?g5J3tf`|O*-@mVBCK;NXdN6fAh zYm<7-8SHm&yUc#(9Dhk{LdPR9)LzsAQ2(7@HiR(d>9jUb_a@HlBzB^FV2E2qE#t*L zr(Wgl5c+En909H&;Gyw?Q{dojY+CpW=>ugri=zYM6={+eZ5?F;v8s)2tAR}t*-`5# zM{7Q#8?ayXA|Y|P=KeBl;EqGLXgya3&DAQFd4ok?|L4v?GH2>O1S%%*tLEUXeri|i zjZ&;9D%)lM$`~2*7->*0VKcAc*aMb;tj;8nIcIA6V~f!g&e2ru10MTR+nuRiS911> zO(c>>Y^7d_MW=jaM@ zovVHVA7t(}SYy+*#N$AolTGm)NDx?*tH~7vZ%hxJcWj~F?P?1~#jyScBvn?>&2Xnu zlW&i!8sE=9dI#$wN`UK&F!A832NzNO26+IHXqWpGQ2*ENkt;3`LUi|=nXY$#FrPsZ zt&5lY`*b>EbnAky=$3W7F$mO&s&~TZzHL2$WzIU5Jay5%*F=EWD&snELT1;3G6Tp= zr^7CN?dD>9gC=w^<4?|O!lke%g&^@KUV-BEe^mLKrN*S0&uoEa&q^8E{#wz(S70p> z2ORve19vbq>1{?lQPl%f+i#B_U^0>5q4j-O`ARVh!Y}g34PUDcO7nqsoO%hoO0|vM z-@J~%*{dJWeR}LU;`sI!?s=H2||Odq$`Qc#Pkzw>c^$v41TyE2HlM56i=Fy+;kJDs}R{p4eF<*@DjPLklWpCnr> z2j%MjxcF15DcMcIt)IhwpI_%pYw|#vdLV>FLV}?w;Mr0C7_(uXnh^V8>(G{CDtOb$ z#``}WC-%B(li2JJ0+{I)m6U`nJps$sa^Ec4TYS7RaR0`z@rJ4Q`aD`Ru_xII3mUSF zh_-8^-S2k{f+JH+?C>9ampYpLI_OhvXx5XapzkG1ZT)t^9b|Xz*x_@^&^=-p1uq2bN^ zF*7vj8dQR^NP`P_qugkDIY_tZn)WWV;7i}3Iky^Rw(dAQAXELVgG`3B4jmGlXX>a2 zfzAPbR$Su9jCxC7{am^0z#~sMx`ACjJa@T*BZSXeLB%5{sWMJ%^`)-R<;t-Y)wOhN zx=mzuB(-P+b8Jrw?o^2AE>KI=GWmYwh_!ygV9HP3=(|kykgCo1*OS!kGLsO+OI?Zb zr4R?BI|R<)47wJsV%f&dEwSqcLaNsaa-L<*@@U}(aCJ2$M7jGfSKGmXG?rZ%6^{`8 zG2GwyADvu0A4#?*e|1bLY}ep1kx_cukMvlLo0WGry*lz$flu{AK~T!hpg@ysIdrrK z6Z*OdbB-}Tz`=0Hdh;$p=hvco#|@%nDLf~U3tPx22wX^yM@g;Flke7S$r!#PQ+NRN4;ZWD`mm!G%R&h=KMgq#KR092X$prDP281hFij_-x8>COU6*hhP`>Ix6#MqmOB?$Md9>K&!)eHSjXMrx z7DuA!6>)!bBk(ZIb}ogF%hOc#_2WK?s$zeeI>!Z(%LW zeL{``4@5$5^E)7eT5Lij?FzYiuucil_XhWargk`JW&ozOzq$$6xrawiG_onz9a|^P z#rctAM?*E^k6cOXsst{y>tOh%K$CSt>2&2X{;8sJcTcC1V%54!gjC8cr*OdwKPLNP z#y}u-`xN--nsU!2ln*GoziZx1mBF2>=2t1pB&dss%q(T{;~f@UIyZA1A6_s<#=pl7 z&W70hr!wmE-`CKQ6pQ^L?EG;@;XIk(T}-<&)HlRhr2Z5Ud zoO+<|*z}oiVP4kEd9_ZB&%JszIZhD-RC7JzbY?YOdz&@0Vv^3!KEFyzY%tO^j4K}M zBf0oEC?2PrX>uW8=ub5Y$ZkEJL0hl)GUJH%WWo@fZsL)YIX0C3Pr|;!N%C&8T+Yy! z+2jljpGNt0mEqrZs(z4X&~oiXWl>VH{gAV)N?o@<{Sut0nbQtwYXM zJ#6rbhq~^mVu4$>O3kTC<)hh13M)C8vtK2G32ns=g17i?oaW%r{`MADrU`GtsQSL( z&mNL#;dU`W!@*K_PT_Yz!7%s1^=#g_pL(52FgMmO`axyky%v?k?7^(@a3FciYQ<4l zr5cgzmIpaMBFyYeRn)B%Wts8cX0oJm*^v&2-?wMo9(YUNmv$Z6?dsPn|D69#794W~ zsqOXU-_K0I_Ydqr8nX&~`xN%vOPL{LwY|mGZqj~?C{nq}hd6q)V{`NM@DB9+lMn_N zp03)spbowJDGyH3;F_3IyDMgRB+qk9(Kh>fPS1=x3Dph%c-huU+bClc=res3vP;Ll zmv>cM2j!ky=^(@@^S`xVmPS^0j%Qp}r!m)WKCwARe(k$$%LgBloe1T2z7-T=6lNiC zxTRcsK%LHXuY929Fuwi%$xasb;&~}My+HwzwX{+rwsRyqdCY}*!z&6;I!ugPt%)P{|>uS2< zI#3hshBQF9C#%y4p1{U`*7gYziF~7VKImW!)+X8e>K#HJ9d`BhI+9abCj$eYu($s@ z*}?^>=@r%t?tt(=c4e3x&N!BDt`NT)jkn+JGX9UZ^KMAmuUAGz7&Cw5FxA)~_!aS{ zpOe+~%mAz%(!CJww5~p+$ajpb#r@#vGdB(&&+0h22+xJAQ0rC((Kuyr4gfM}8H3O4Vr(k-32IBMP3; zstNSy&!)^a9*wn-@Fpd>CU?h>H>J+UXuaA*DYv0gxx`KMYq~+Aa z{rpl$%v8&rsXNo@nptVub{DgaC%biPu4fsSt#xqhzDe?;QPD2-t7s-EbTK|h5KZPk zybtpKSSJVU8=uC8e&}a6JmyXInq}VQ1G}wkVFAiMMG_-7+_Su+|8nu$APJ)K@px3O zH}%@1yZUCg6X*+zsb=@F09*P#>-6lpQT8Iuv4+0tB@JBQ#}lO#j!uPYq0tKZ=I)ffAW0&Dxo- zK3VhfD=lLx0a0%mJ;Gg6?1HNNhkhhEZ8%kgivijx{YuF_=_Kmi_r?)^6Dc=)ZNbmw zo-UP|c_(!y5_S|aCw|HOs?hG*d80UeH@ZTc;fKz|riOlsq&rm zVA6`+m6p1nGu>mjWc?w5s9b|9@7l-3#`11V?@>I8OZwTb-um#I;TF25D}lB7U4>)$ zAi&W=b~fB$=)cu*9Y~)!U>fCCdC;|z-7+~xr-vh=GjRW=uxWa%-ph1Dt&5*g_Li2G zJhyDLbe(Ue_jv8_Xg%B_^|VcPdu7NgL#49r&t1E%*mk9I2iSFlpO28N(TP>iHnFWf zyfkpe*`VxxO$J-Rw1&#VzHZcSd6*)btfw?;Xqy~+rX%b$HwbFY2yo_f%lySK)Y|_k z$2LC5ldHbREgFuVI+Wq9z3e%Bbd_Nh(f+~cZVA2inV!g= zZC!L=(+xAhnanMVZd#fvjCy%8PjrYKXCXfA7x1LJs0GHK3HAQ|bcr`QHw09_6nwn0 zHzT2_s2&PiJaQEx+1~aEjp!(7rmUu|FZ>ezEH6t^__6*_m=arMXm6|e)ceGiKI77k zpI^(&I0|)O7>G#!8$cqF)sty*bubxY@5xRP>eL_m_Cv!^@npk&`tyYs0kan0b_rk5 zbkc`KrV2qm#fT!!Y}m$$+rR{#5g_wSy`QsG9rPhy#`blDwaClv+?5_u^hw`ytgrsk z_BM~H-RDFr4gZ*;tt|ZVC4BOnMCQ=T3uj!0tu0KGznPhdT(ST5 zVQBO53~r=-WN2$_#mCWaBD*qDRrxe{VmTq@WT=5v_sJ(a(RN(*r2gzQHFzX@PEm0o zaX~)(`usS-9ghE`>=qGah9mA!A(BjDXsR3j_>$##h9Zwvu0TvcAP*WceR~UuF}=$f z>V;-D`moHeB5GNYMPCkiafHQt$LYX03@`sFS^wAm0siOTSy~Ri&&-nyj4vtSab>IY znDn|dpvsXG=||lg(jQ{*b|Jlco$fNRGLFWmg9~g|xFg)p(nR|g&C-B@?5!Rd55fdAa| z94Xq(m<$aIqp!w$uwBdmc{8^CW*84B$!D49!Yb^pi`bZI=0zgB=yfu?#@j_ey@# z_=~Z8(GR$35*k-TFYNokd0w{D%9DfEff`455y_JrL81H{ZIHp&*wzK=55$;+;a__z zm*t1*PRfZ%SSN4G@qhO0Y&;50-5p}u^KSF#yLXnJ-AekU-(e<9VPUKfn`bQ%-sCf( zzh^;i1E@k+r=HvBasj8{Mc=;N_e32BPQA6IZHgYV>s+JF1*gn@zy`?vJ?G(VbG>A2vPIXCH(+XPg5_p6xZ7{ zE);&$g{gKR95H1OBEtxF5k!3USM%gSe9gj%z|T9&?74qwC?pX-fq*~TqY{H7l_eAX z4x*lL5Jg_8CG975FD%J-DC88AprnjenMkCE!gu|AW`f6~09(AbBHAIuCh;y8T7$M9 z>iKW{qw6oTW|~bFBX4z+U`YLC@S$9uG%|+C>(q9}beBQX2Snb2`J>jSQfPrvtgVIf z7RR7&>)84xw6nH7TjhHZoGJt()}1i}Ge0CpJQ&(`4HGpTn*7+!^*)?R&PxNvWc&z*s0 zmBjcO9oIex*Kn-yWrocRtBwb7ULnSHWv#cdp*vJ_k*d#gN$l=`?5PPzpN}Z$nDEi{ zT7Qe)M-5Yx%$*x`m3-G5{&o(74iI+Mxv!XU$62d>TKL`>xMH0_>;xcOh-}RD1le!S z*K55Wk^SVWb35MX+O=022{y#F>FXQ#^EKrT_Gq_Y%>;4*6xUvu$Ix}7V0cI)PI`6v zh6WkkQ7SkXn03o{>Pini&IWaRpw>)OlA`@QD#&I3ubrP8^j@2*H@vXe!p?*X3Dzl- z5fYSC(b}5hfn>VRWuXp`Xqpbtjwny&1Ud1{;H732Is>J|eEKFnNL`pqR+1bVRCpEjLQJ?i_1(0t zcK=~lwwGBI_7%LYYLY`qnSaCCo3^!G9DkrpHs+AuBSW&i75lG00ALXPVqR25$M{g% zwJ(H)rf;{8+MzOaD#Q`9>S!v9{`J=*Bz2b;NPKY(w7FSgE}Bgx>cKeS!|4*lZ%B;& zd~7t-BhPO;;#4C~3ru&AeMb|^_(zLbBfbqyo=+TdL>S(OZKz5pYBl@8E`q7qNjxR? z+h!YFAOtnL=8~stKqsnsvC*dPCvIx^tJCxYL-MfRxqm${F}r>IbPD=U%ce6l$#nF6 z3N3oYHanebWp-I6!$E2WPDtuFJWAHlWlr@N(5VQ2?dUpS^WD*OK4lX|imo=bhM|QO zxgtiNU@TbN4~NVP6?WyvUS^HjJ-$8u!puoam}3NQ#d&0k%mr|Y^Y5WJ)EJzO(*ISd z`>qbCW>rJ*;1>4HxCG#In@$f1Yt_2g52NN6se2JVEE-U{GBS>*D5ljdiEVE?`E$EDMIH5 z&-aKvRK7)4NwRlaAm9;outSEOJo~Z%rW~CLuH(hZ%;@1VXj^5t1m#=@Lmu?}aTd|i zZ>HUvY*ZN)iaIaa)v^jUao4mL6?pA=z_yp6rF^Surbl{N)pdq;z=K8aWi?`%(BhF8 zcE?EfS&z1ol=U6UyJ6YYsn#xlnfYZw$>a{wf8NCZd@XJla|fIrh%h?EF8Fy$YUcfj z^GvDpi=8i=4mkFfh?c;#j|(X=U6Vjo`PBz)DoYt zHP_WbER4R&E#_?@Q(K{2W|LKz8zS$YXl8lNUcWh~?dfZcrL6$C6I8y7h=mBQ2ywJg z;x3{Ggb%@cG}p;BzRR{0LO}MJ`0e$DQ!)LeRc+g6zN{b0vgyRS4{>H$nm^AN3bPuu z#`|B!(|(ir*9^+!p2>A{1-EQ1%&T>vnc(}$O@2CPWlqyQQ7Nwy@{x96Xr`z>U__(k z+OYcp8{loej*y%3BcvfeVOyzD`R?CqMEr{%L~8=@O}#AjIQ2OUr_vWCZaF_r_ zMD>`ad91-n(qE%6l8;GdI}iK0&utl(bR1GoUu11_Z^4zELTa525dhe)HN9;G~LI@JL#P-*vr zalH3TsQ!(<_nrC55tn-U{>XM5{O8*k)T27B9IwX@EG)WHa+0~pt)SnQNZ|h$YxaCv?}jG zrN?A1$e7>~|7#cl=A8s2FT)xs&U4ZMOocrht}GE zOODcUo$7NQ{8pBJ1zjoo2(n${lHH!)N`^-AChH+OJla-W002d z7tuK+=o|gFXsHY~w)*({pae;V=WlmQ2j4=4ui7Pw;ExAj93OsGJh-`sM%>tZS1zv#djNjGr+f}iTp%4ZMi+?^?*D97Y^pQ ztzb08Y=E4zgt}aT(BAUz`a#tMUNm(tsmO3RVrx5d2EfF^&jNF&q&Sbwy}6k97VX*_ z@&oTa%sr9r#7U$*#9R^YeVLTGpRmr?QBsffoATJHSZ1Jf%o^bkp14#~;%7wv$mzuV z7B9lVDq7qw2Yu?nkf>{yGSo8#!X3a*l*l0-ElXU;@lT}AuY1Yv4><0C-Vp1@gzl8K zPc8QL{R~QPhYbn;aG+6V({Qb@4Vl=(y<`)jpwWpwm&6*w05|9gmwmGbysZuw1Ch`; z6VHqG@gLU-3Sse-P))Z-D(MIAsg>>1Wt%zmsWNFNTlB5UntjeY!E=HHcE;+_3ZgDf z6xW>#%CVe~oHBZeymgt0{Tx*&1=Ju~+`^$t;@{km!r#+|!1yg(al3J}pJdnt<+kQ^ zhlG~w`=&ec*q3(9_pl@!oWx($O=;EOq4&J2&VGDXPf~J6y4(C4BmzBKjW27&2m-74MB#-3b$fgMg6RIFyqP&Y` z346(cpTFAzLyt&as0Xoz*>#>lYj9b#i40#@CAeyw$80w~amo@oft^r9D{=C@fs6=< zIwB3AI`jZ>ZE=FV>a@gjN<6;iFvU_ID;Rwv(wezZKFw6}i|s+R7P7LF5mbkV%aXw^ z@N=Z1ez^g%>B$Rkp*hn{i7V%CB{5HlnvZX7nd7>v#-4XG`2RsX4aOS_Nu$jD7Ms1C zA3a)VYh}`)e#wKfBO>IzF-hGKrAf3cGWUhhKE3N>&ML{~w@YboOaZcj`tV+@===d4 z%_;{s8>)~HJX%uiyL_$t-_N#!M0oiySBNS|cEXgcb@pNS6T<+QEGc4iA=#dkuH=!!S)1A85wfb>Y|(+A5j*aXLOweDV9f zr84Be8|~b)fIemJE233Pa#XdBKaHvA5S%;<3p^jB0g9`x)GRL7H5Yz*wLP1cD9o5R zkML;wqQu}}H9O5`l;uG`^ex+OSt>{E2C>3ew%+IYp2M^)NbQvc%98yKQ#*8`-Pjxx z`5@aYi8-(B@9!_K@hl+VJdl+7CFgwdEXwpMuX2w6C>UEiRnn338H+GIa+;z1k|*t< z55@}Vm;h&2xVc>^#O&!U>?mDOsZMCneKR7vLttWHr>u}6a;1oAZ;^9fzM@7u0WWKE zK_yP}81GwSQx&WF&MbAh*wHMc#vYQ<{b7B{jG*N6ui7_7WY}YqbYk0Kxym-W)Lm`c z*{*WF$4`DaXV>bhRua}<RPK=75 z6b0ZIydG)e&yE^}LlxOkF}3dMM`LS~qkDJXV{5hLoI2eJ@p!A+7h}`P-kV73oLscN z3x<9uM_I8&)@KgMNjT0;HJCpkt$t`cajepA=Po?A2D>1zp);*Mo)M}{3D89HO!Xbo zaSdI}?pTfc(rBOZk=i2rckyMST#PA7@3YOO^|$xO_}kRX`}1-ZA&OZ8M8#gT_qTt- z285*;2@p3xHQFPg%WJFHy)4x&Ia(e2a+rt*{=CM)9dZ{{C~*JvkBj z-J^DPVY70=4nV^kfmZv3ZD0M36YsA#z? z57cQJ=S#de{krd2C!4bCWb;jndOLenB5fB50RXC$PM}GKvwJO379i)U|^}+-QKw{Gb~t(s+X8wkSGNt;*xudXfK|}$d~WChMcod zkCvuxOOPDO*Qk>l`=%hJnq%x}6XnjPSHO|-3hMsm>jt1gQZ8<0>M!LC(;7-ubfqd7 z2{N>OI6?8}I~{(+k8p0R^UoTn@+(1zQz%}Pq!dMPS7$y}F(_^4==GAc7mw8ZCJ9-e z=fNjEM{2#3fo6HY78K{X8H#jwqG3XS-ri`8hmd9z5T-XIv3~u#Z&?hFvJICPc-e0# zk_n0qh*D!0n0xV5xBJM8E$S@DYY&$MwI%s7^3a%aWC!<^z1L*ucoa7Mg3JX35$qUF zsSC)7r&u4(c&%l0UlKY9*jm!&GlGZ6@sSrI@=>mp;@lI%Bf>@DQ`rNCc8yJYI=_@v z>+=|;M~R`Zr?Z_*G#pb{nI{Ws@&>DW1w~I~aL4dM7&Z=3e+)T$83xdF({e#(L$b$7z4qUG5{dvvY^Yg^n_fl2aV z21b9E!KE4G0*JZ@9b&QW=H)LCGz{!#p|g$13Cqa71C_@|D&dSpb;Hr8%@az0r`TAt zAJ0H?FJyu9cZtb-Xd^Sr0?J?C@I_xU`(-{l{5b-6O{ z?Y>{z9dYMLwTsunXP<%@@V=J4mgbxWwqWPIR7aPm@K3GMSG{{9gVCVSfj?ThTAB-~ zu)Yj0^s}Of7wVwB2iprbveeTFBL;@nPAMDe{2y7C6iM#e|9|16f7hQ47W-#?vN%K~ z65|77Pg56o@X-N*b(5ku~9N)*HO#79bY*?4ryTl)sC@vtR3|Xl0yI!`a5CtwM3Aw`f7t3cI zY;;30o7u-`v$^?jpf*a8T#C0QS1{HmAC%4ws25rclHA4(LRK>$*w#LuuP2LYu}Fc- z$j!BOT_rAb-!I5p!qI2|7Ud&-4XJHs2(&0coWRcd@5lKcFC1%wzk`7dj>zicVcSjT z8=7T^ZdDA}eS7L<&W-AoC;9_w*AC9oYHq6fee|ma^)fhfzMDVRGje)2zOQf3wf0~! ziw?>3WAsbF#1UbQ#Q zn9&x|>TnTl0WyPM_!g^h*p|7&v6dNC>-yU*o9#Q)%^9YV|MHaB+@z#KPsgZpa$1U0 z0+O+Xxq+k&GNxY#tcml~qnV#{){LLGZZG2j>U72hATqMd346Wy0qSp0S;ccoFdF<0 zWk4qsT_S70@%UKWULX!HmE5y=z}N-7jqN`P=w}LO3N-8?6KtdA$mS>5eqC% z5$7%Ao=)c#M$GgD9Q^L(|7e}xBh1@LVQJR#bkQ)!1jaJOrEB`Hl&y~yLBs#0? z#0VJE52XkL@!C}pT8N9%o9Od4AyXysD-*RTG|>j18ufVAjFaYR^DvqDtP3NfK3lJ* z#UI<1qa{FE^2D9eCO<|xsmiZ-MI2}tD`9t~X@iTtiTJ9g5s-0~M(UV-pjWNgX;^*7 zU;(_mmPPwf4Xj;x9!4IqtsS^fS9Hh5{D<_?!g^9kfC}E!Ta5ta*jb&i-~31rU<)pc zcHg|aoBiCq+rNsD|2z->k_EgYvkr}TKPxT)h`s^r)W1AH{mxSqxoZpnl}f_Ac?1Qn zCrHsz{wwHTr}>wnciljpGCu7X;BLTI8%Z23==@rR^1Hn%z8i=i%1nE1Vf(?=Ieoxy zSbeGj(xa^@YijIkCWsrDwX1JZD}R`Q>v61vRB3+$vahvO+sd64$Bf=Z`bvj@Q9UBc zrThx2;1C9$F}&hP>!cQopFg;JjY+RjK6p*xAF}dOdoKP3i~Looe3mUb&Z^9hkCj|H zB2sjYLw!A+>w}F+qvd{GzCdtW@nt$K=1Sb2lhjquanL}>%_ds+1zoBSwK6PA~u&THf_3*N{a@Y!-e3^(WBl+D=r>s7NqF8_-Vl&oid?oaX`B5;dgQ+V zyS34s=Q9h+JI01C@is@4VPimxQ8QF2xgBWNUt5*i``MSJMRO4=<3F<$&JB%I#-`4xb783V2&bHe-94+ai%}WAY?963vp$w`ST+W&{7(a_$6^C zfnmRzy&M{67ku|Kh0n_!>M2D>&x?#;}rjkFXH+CRJEt9rFQ6!7#2+ zpT)ve5j(LOY3i$F^=?%rk`c?4q7#t}R8Ow@nkIFry|*&}tU*wN2?`#5jYaanc$Yh&teSNR{=uz7` zRa*X$V)Ten{MA6MiB>!DR~6j1qbQq z(9T9yj`9y3HtYZAnJ>x#m++oiSp%dj9VSeRog<8{h0f&}sJ)Q^-Rho#JIMepn5jpo zBUV;WImlHk)0p~(hIv`NO2uklCXOW5BE#;7TZFJvI z34-W$jX`fF%H||_p#Ah+kP0EJ4y?FbaMO|fK{wL;+yV6sjbkR1_3}@!4U10r>gIGT z?b7PN3Z;~6OUgTN`iHHj(!+n&KlwJAT>pBGZYuzoCqd~%Mj{=_WHN4DW;kFgtsWz( zx^?M7ex&v}=*}0N-fX?*;5v%h!5_xgz$Vbt7rE+BKQOKhHGG1Xz2t1Y2ll=B^_zj8 zHaAMsKU=o)oUTPK!3KD%kvOtPl9&MUO9sFPlUX3K#rwWMLVGZR zD)Cto?K8xFTaK{%nlqtS8L4kpHc#_nXhPmTTaZYIt$L zSNZ=@z6L9Rh!0DVKo+nfu7Nj8COAJ;JA`H$&#YAAXi)e?Z3*!1PWzjGlF@hpAve#%HZ?XLm0o6=;~zFRX87utMIuH^Z{yto zd9}EYl09Muck~XNs_HB0Padqu*0!8v||Yr%{6#xC26(J^G_nx zBa^Q$OYXkGx{udA~(7dO^U)f&OXrT{^NK;0y({lo_CXikrUiB_lp$PBgA^ zD%!FKBr@Fz8P1HA;L_Jf$ZT>~od=QfedWITSBhl?JP%(&bA@#elJY^9_ zB&X|kA``QxfA`{)?AN}xG_+(iH$gyO(B6_}L2tLRh-;FI*-r}qv^kR9!bgkeAWu{S zQqi+`5(=uBO9OX{w7I38TJQ)+)hGd%62NVXB)MDcME}A9bgw8y8mATNAS0s#l*8WP z0OI(`h^gjjK5{WVzEjER*xq6HK+Nlv{K@nkd6iQ$ttUL@`U2&xV+q1IRfHRc7C=`aK24+} zb46P~rqR6^wkUS>2~Tn)w>8In0uV$iJbPBKVL#KYqbjbuew${$sGmmqdNmzU1O&+1ho?A-zipcs*7ry$)HHnoSQnmY7>=AmihbSbGEIaRT z8QylWB?*&@cVADmj`iU^qk@ID&zqoOX;?>^$6sFXd%Xve{`;MM4-e=vi;u{O3dg^w8kVWdI&_K8)pE31J~OsC;`y^8Yh~xX+;(H# zin=DyGUP(#Skl<#l@`zbO&d$k>gprV@#oj=L+B^aXPB#OMP|+54LvN*6T<2(=bnk* zskB8s8oc#LdPEB&7v-E?r5tkvyQp)|g>hWmXq}&4k?=9L$09?6+^U?3aSNwA>UNvi z#pKO-2#{TrX53tT5CYRqQU7}C|IP1!KgF{MsC92ZFa?+!h#TCJVV&@Xh64g-iMxya ztvN^~)5Wi?^F0$~@gd|1#;m07$;t#s6_Bk5YG~hMJz@8G(F5L+mtq4mF_{5G3qypV zb;xuZNW@xP^8!<9R&zUv;Jqt5-msv^F;shQQ@qjD@QJbBVKLu9Jqfq(oz-tL9YQQ? zbF28y4Aw56pB5Nz6K3pF%4IQG8tX*l8D4ruw!4Rune=>SU(0o)g@GC>z7m#vqdcL1 zt>=U@Pk!W28xt0_!SkS9m*WHY&J4af43fDD)*_0`%b?DE#4DqM*f33#*vi5xzo<*- zOC-x{v2tXRKZ=%Q(5+^_q3x1)N2%Agzp?Ytop>K7p>qY(Nd4tMg+`Lc-2Qfh|4-}j z3-l@~1O=nhH$-#A)R!OycxoZ5Rdn4Ay7r$|e1N8d&)JKZ3u|vSeRebt4>CQ;Uv7EB zdVRByOIU*bn@c64@NCE+HOe{ATv?i89$v9BJObHio)2=jxhuNEcLphKk`J_I`MQZH z=$!ldQM&!6PLf=U&{Rg?HrJgfd9+Cx7JII65PTrrj^2-LBaIgwqZL2hW3eVB(j)e* z!wj6#Krf5Q0xb zj#Vy5xEGW^7$lmTYvc{L4rO8W5lzY^igK};L^v?jR?XCGfA!?1Fhc2{OuqdOqfEsKkJ_k%x%p zp^`fTig~fdI8jZ$nPK&qwKLRc#;VsgQ7U9mMj){~)_9DiQRo4^cfR9y^{dY2+I)8- z=_xRboEjkGn@jI|HF#qqcG>W~j3mbA`zJ=r3+gSXU-ji7MXp>=l0yp-Bi$=F$D83L zlx>7LV&Svnc7SLVzvS$#ossNgl0{RfX7?yLilL+3+7-+%yOSw*Q)GZ>hp8p-u3U5Y z6e&HTa|t9jdF0EX#84R;T?8SZ>bL;|PN<6*jaSpx9Qc#j>Z;OZfvXQSLojr0s$^ofqK`XVe%IY z@t;E4TZB{W#i}0m!jw`f>dOiExdLYISP(HdV4 zFdZkRdx6vg{=vpGU@uqOr*Df`d#mBV?$nz(ei(@Re%i=h0bn^R=cNW*laYBF+P*Rv zf9o%H$;R$5pcqBOLBsN-?lP;Br}}=!}*`2 z&)~E)3@va|STT;Iw2bt(jzZTp77w;vebJdBh#Yyoxq9Uns)*5kNvams07(x-q`tYE z%x;83B_VHJW2U&F_fl5BBYPy~T{OTL- zqb?deCnHx`pz^W2p((uqXtz3NiuRU_aEqj)pw~(i{D<`XNc=#>CT8oA2i$&B>yeCb zRHQek@RVY3@sEzWWW@qudP_MRP< z4o>gx5S~;(T3+Cd2j}wXivsHwy1BK0I!GivM+Jb6yAWCCs!}(S@(HipHKK*~&c=B! zQU&^)5Q8A5%zfR*L{JBr6yoED-W7)OOc;&V^aEjM16oYQtA?_Ij%II&n0y8$K=r_p z_mvXblEb|HylPd6%U0cVK!#GoZ(G+tK=3Y#c@Da4fl*S8j;x+1tpeYgYt1hxb+Tzl z^d}$I=A%VP`#$2tRaIFEKJJm==~~y3hc0a+t;%1|Hb!+6TV2WX`!q{-D%8_9bcH^<`%;Or?5QlzQ+@QS`f8aoCf!vdw;RprT2pD+www znmjUZ5sMa(U`}?lx{NozKIMfZf{9-)h-YJt@ z5EpiB2Z>G*75Pj!ZGZ^rvVUJa;IzHws~;r~9rQn`*i9fKD4V@TSXG?A->@G?U4pU9 z-KKH~9={ffOkH9&W=S3%llH8fz=M?j3hPg-k>ITEQ$YhE7+jHm(vJ?mC!_vMaJ+i! zfvc6%p_f4G2(bPZ?eG8wkrz2)wJ1w%Lk;n_cAKuB3dW2s>K)GNd>zqpV;w(~y@N31 zpXr8KMIP}gCg*731Yqkti(>goc&qnht-mcF2;h*KfV@O z4Y=ekv~;(6=`1DP)={anNc7&&!QQH9uWC(-OsQnSjCmbGVNb{ZGg&sc6UZ?J$6b&M zSmp)#_SxdE&%o=r7PTeif^M*_@1r=&o%nEpY#&67X4NM%%S|xSn8loBCATp(zhympUJAGua1}Zl{A+-pux3?lYZ`*LJmK^y~pbUK2 z)BeIo5kKpQ1sbnh{dx~QUPy?4Lb2se`49;iKSw`g zxH@?$3KQ%T8!fw?;g9ZJFgYf}s77d{QJAwT3G4m-M_eu+*G^%!yCB8sOO_1RR)z(Q zJf+%_tTsc<3#?i!;JMds@|iY1$n0OwStxWmp$#cqjOOIe0@e!(6h5BSNxSe2QYU@U zydt1?)@yF2660l#sGpz_koGr_1@u-AhUfhE9qCiCYaX3S>P8wVnW`)*GZ)IBd#+Ee ztfMy>S?l{U7{2oF`0@x-IMmci*C~~rSo}^-MQm7+X$LcV!R7{mhJvQluXPp7rI;FI zYg`KaaQw4VfJaKC9ClWv5!Sg2*EnN~8@=OyjtU&SBt1cCtJ$2+tG)JU$gur;G-Ro{ zDp6tnHqUr{k`G5k4p$fIH9w7{Y7(AOvnhKgPTnB(W~(~V`K`w&bBei9JU#C}!|1O} zb}6N-6jIGI39i@dO9!2A8D8f808sIQN2yQ^ih~{lZPs^ozQ%>$Ww53RzUm>9g0qOd z#-wj!RLU7+#8elgbKc|@o~Ey4dMTQ@ohJL0PC$~Ybuoq>D*+V2<^TF6!8qhLKA}`} z9Lc`TJoSRuDJ~*65b6_1cF@RRK&bqKE*odN27i9bg-qTQ!x=z0UYsFmwmGSnF2f*^JqU%4}Qs`ccTGPBc5G~9x_ zC3RA13AO~L6?}9z@P<~%)@U^Pu#)fB6NDKHx3CL((}5(CQg3l7RtX)(SfwjT3L-NL zix=jgii1*~0`FPT^8m+e6a@1*hk>am=4N1aZaoUV=!C0?4YD(T_NPEm#woF1gE@m7 zQ&sh0z^QG$Iw@yjQsU{1(g=p==zt413U3FndRsI+W!5MAJsPgTvnz@r_AL>G6^2sc z$+;2)`pyIDYTw{mB42D)MiM#UsC;KdH)ECQdD?nK7V4VNp1#GnkBb-567GgoZep!U z73!YgoN63}?C>K30Zv5HWN_y~%9C3a`oj7ne*Am0Wq z6hd!d7*3eDvF$aV54f?NoU(<_X?jcWY#RgS_UvFUa}}ljI$e!1`r1i;cAKC`tbC1U#F59m!k^V1}-PJcf7XnJ6Ge z7RFPoCdOt_UQ)m=f7xp*maz`~u+Km}!G^|;Q&Etv6$-8*)NInvQJx<<5f6kAAosS% z61m!%%=EZCZKFuYE2Msz_AO!DXJXU020Ll@Q=^wmc4Rpxz-$Hun!w{77cl+p*f?7H zv0n1$3_UI+x&0H~yT42LriRx-O0VqBD6VHwW5aGv2f20Vd1vXRR&^)0jdc7RNV*aG zrDI^Us32wVzVz6OGkj!OC!-4HbR|Relv*JC-ZLM2qKymU7nXaavM*GcU=XsrI-`x( z`HA|dp4WcWci1mb_)9f}ShuugA4$G6<`sueEWBwN=$d0*Xa*sbFQJ^RV%1j>-j^A{ zgRTYvuDC>RDR@cMz_Ol8#iCB{>*6%OIpP^gz*<|1y2*`zj;f%tU% zdeps}QAk34Ib9?1{!Q)cD?9i@M1<4M&IW+7_RN%_rcO1(l$q$Z7$3Ng=XY57hWd;X zjZl7Ny)#g-faYRkNjI@z1So`&QIBrUD-w{gbspqVSi^t$y0=E(TkAoFmooNRFfn$Q z39O;&r^YrbT(TaMP>kEuS1rs5ItuS=?T@oWUZA(hVkGH4%vORva|W~N(__FGO6g6I zMqRL}^*x7QBQupxZQqCGEqDrCC&+BlkUzXm4EbxIdVQ_LHQ!y{&9fBIgFTM-T430p zuI;A?+eP(!B13xY&oDb4TH9G|*X+2j3!cWnt-Dv$3{L8Tz&3@9?j!0_j;1F3c?9iQ zDdDS(gS8 z$iOr8I|b&;OX4tu$H?|Gl@p26WK7`-l51$ACM|l+l!=EEH_oMaki^_Tm@O+b2xPAN zm)pkbtb89LC79KbS^g0xmf~n_cs#w_dQqJkv7!^c-L&mUBTX}4jIC`T)2V+rJaYOE zKDzch8A@H;(_Q$9MV*!!8q+<~TDW_-`Rhy%@Mu8quw?Kc38nLhOT^P4ceTQeJFEIE zJI4si67!gUB}xF7-68;VJ&9$J1RU^j>B|jBZ2L8ppNZzw&m@B?nX_;m^?S?R0o={6 zm86mL1Fb3%7sBjd^u_##QTF7r46Iq_)~JfW?q^iO^DjT-4*Uf<2g~fa_)l|OLk*-h z`J`oHtMS6=0B`H~z-@kTTQ4CuDtnw{(9Z5pJfgF(>dS7)ss5kbCxwAM66D(yod|}+ zXji!?Ab<@mxPp}=^pION#u9~^@>7-7Rn@raWjImWEB{qwX@9J~OJrfIe6sx>poD7z zK?}g4R-eC6vZ6Iq2-3ap3ZEi>vgSn}!$vq<7eSd9M`*OlBIFT7n(W1ZFgLtqAZcJd z0mtx)07uPKqC-pP1g%SxDa$eAJND4l9$BAgb6 z9@R`JZDva{#|4vVkSuNMkSPmVSj_+ye_f<7Q4_k(1H{PqPCr$ZrC0+}i@VZhT=^%3 z=`;&q^)IxB$bylMYQWy`iZ%74WjK%7DDpJSHo3nd#i0y4(#!x&5cOND#;b`_*Ml4Z zBy?;F7|7lIjI)h|S89#bM0`9nXl-7&4ciWdv2Op66KJ8nt6v2y% z@FQ4(C+z2N=U_gmv@mV&vEVX&7ymjTV)Shb60k$9l`Pb$ic!ND0W~ufrW$C$O`Ta( zE)9k!ScMnNrYBu9r0dR$0wolGt41X~7YzizqS_vh85*UZwlk`SFZN1oT=$#q$;g(r zv+Owq+YUcwTh?ylVQz6p1})X^Sng@=;#A|mSy<*U27*B1iyNV_-&@K@8h;PqIUlKMHSZo7Ffk?s(*y0N-%5f&t1dO{up=iC zD9?cJAUa0Y`jcF_bLUJ&U-RRTd37%Uv>4$qxE-=y<>@b=aI5Rqsr?G0k1HihGPIS? zHa=jvS;{%Bnh}e->)lKKs<~q&pP^_$)7LhZ)-Z}1SKUs^iN4X$M|ZLa!ut20TEPA_z_ zO~>h{)`qKy$=EE+&bDq$y9pX(9izxw+3PWwRrzduS1b9KhN0AOTFm!xPI1QLowT

    #|A8<>7%D$E_-`KznM^fc-3%A?)Ub;8?rD# z*J9Z^J_3&ygZk~;A0Az>8h4k*!j*hy-dbf>X&u>9umatTAxiW5k0#R|cBrrC4L|Ba zYjR$ZH2SnDPjE2{SQY&vG0*17&>X#`&i+cpnN#JL9uKxYO5tXWJ1NJ>z(!R@BpyG< zy&6Qyr{Nh-+OB%pKQ zB&5y6me$y8%AveBJMz07p63D{!C0USpJZ(*7<^@mj_T}K*S+U)AnCuRDHgqD5h+^~ zP?4aZ1GrDU4hj9A{*=Z!o#TFA18Ma-fGOvk*A{j`5h^si<}px@QKJ;7S7AN|!z?7Y z241KlQGqF#i9UrCTP&{Dj!dGZ$C(R>i1<2RaPnd6APGWORnQPssz>?z^a5eVo&J|wm zJL-)N=Esk}<@x}EMzjX?bV0UTgEip2_ZDMVR$3rTZObgQ4sRHq^PgeTuB{IjX4s&b zm9xsmKsZ9rLBZ<+(z**918a~@`!i2NwpW2JSL;&`?k_Ch)#^;aW{Y2k#>j*5udY%P z^)|X-dc;U8`JArzVsK*Y2XC^}(}zbf-se~Yw6wvMXBn}&Dr=L_Rslga>ptNz6LeWW zI$qs3bRX(8nqg5w2 z<*}z8q~P=Pga`F|%6sx>yi$9WDWl8S?yPT3*8|||OO&e+fAq$u7Y;in_o1-;6xJiz zl?b;{8fo#pPaxHP&3I%v@Z0oxe3I$##sJkdW7r=$>}Pnd0PovQ{XXwdV}_TRPAw-7 zs%ymdM`57EZUcv4*T8oCPWoky8=PRt@58# zef%iduo;6t17m`{c*E{lm_uX=z2uV~C6dLEqn-4$5v_M9>y$c6pLmqf~z z??124f3Pv^=WWm}fG&uvp45vvha7X{p5eMFe53m;GYT>_g&mw%rl~b6e-)z`4$Aln zjOWXKKxJ>`yEuh&Tg5?!y-s*=3}il71kLCSz*6i4aH;-_~w0o$CTFb1BRFBK)NitJBNNH6UMw*?9VVph)Ss;yyOgqq!6$@dNH zEo6j|!a9bi~O6v@}>$!%b7$&UjQym0aoS zzlQnczX@vnd$ID-Kz1nn9h`+=U;mt_t%_zku!T2B)Qqj27XvA_~r6w3-j1S z(pJYS^f$dc)Tq%tw{Y+j$n_~~DRIZluGSnuvLh{~^9JWZ(zS6VM_e+sG$nQOwfJZ3 z7bEAz!uai1+Y;}@l~fQ*%+MUFi5 z;(6B|gC2IEYc~`gzgg>ByIj#LpJvJ-Ra+d|MmiN@Pz5kKI?K@rtf%>}`GPCK$=u9; zAbbV2HTig*_kB%h7c#s0h?V?F@8SHcFc7SnJ$vkV~`3$iRrlWT8sbMf-Tal$9*F1FBy@Rv6MQoQN0PaEFye zADJ-^e>bxJW&5YKs@S|8ftvb}oUMgf=_m=XXf+%tELU$CurT$k?_oqze^1uYpvn5@ zT>i5m_-V|h^MBJj9p}%m)sCX>&6D)1OUJkjQZLt$a!#4nd;fVe0uFm@H;-HR6}pE&QhiN9}bRr5+t!hm)i!*h@1Y2dAv z&fD6pLs~n7PC!|HcjsuFzUBc=kq=G>EYNsOJ`nGA?~J+{3U^m1JmULOGxzWwEU>)Z zlorlS_#RiDF;_%2 zxc;sZe_7*fkDNF8pZ>1l1<<5R7qv3)tTNYRHC{8mB5Rcp%y?5iM>GKh#N~)V;rE9R zX;52~q<-nNfGfiMp52Vl=}HMfl?c;}bgpmO=VgSCbXf52L7NB)M9g>t0jGeZE=S4z zJyk=keu{;GOJbwmAi~241%{}_yvNly{>+kvU(kMBXyF6=IPc6j`r3B8E>dG|1B3r? z`6GU8z}lix4utKxA8SbD_LKJLLF~Te6SO*Xp{^l?G}fO_8_9-Ud=^+>7e%-eXErW7 z=1!?~X0C_Z4ckUsx{hm(c=`-Dn#1*Ya=yn9RnzHA2QsvQRMVznu%-57PrQ_^n9=Jq z>{s^Me)pDo#@qWi>I#SRqtuqe`#*?#KBIE_`KQOGGt%{YUQTXbaOuHgHEYK0rUh%n z+2;hFve`YA{c>^WNP_H<;%MiW1tpVC=PK+M9Gw@wz0SVV$yoecS?1V3LL+%drd_5A z&qJLGeCjGU2sDX~WwLSD{xn%+Qqnt~Qs!4OKP!Bu^2W_M$ul<{6l)C-bZ$>&Ig!0h zK?26Y@k->q1$kr6KV(Z{ko7dRL0G&))+Lx`nkV7-1Xj4_9W(yU7?yk0v<#? zzp_%(UnYFBSAN|66W{zH1kxZ3oW9la3AW3gMgyxr7is#vXQzC(3r=Y;B-z zv_+?5`9Y6qGwH7b*0Dp0OEjl|;&wK~vm)5m9{)PzST5$0l@^s9J=McE^h&#Iz~A%OrsH z1CxsnF5j-uI2z!#TLZt8J-lguwo6^m=`Mo#kT~4uVRZ1D`9GK8KWj3J-0OK2ERX(q zb`O0wdb$K=i4twZ;WPSTbBalWIj>=b%CdA~6DV5~dizK)7~?&@Zm#}ZU~b1-J7Ib! zmxW7&G3VdQyK66Z_URzI^xdMWaFew19&Sn%m+0xA3*I{OB5_Viv8`_+QxA7Rs)l3N zs-TWqKlsKf5?1GOqQ|Z%IVuaMwLyPt%o#T}(voy|5zXDsE+a<#^)0c*=EgweAUldHOdPPR(;%&cr_?T@SH&et0l`p@=cs~x*yLU1|f_pOat9@)U zPW(Z1kqk#;{&G%T(K&ye+HV3Q`z{npt?c)Z(|I}d?(Ef~^9Qd8G_(k*ceDp(kdNw- zZBMRI6`GZ%c=S%Hr$_qG1`hcQ4`SBTh^J-A*E?XFxj7K;G7nTjlb`8XzAk zG4G`zR+vm<+)l%@X=YZTL)8R3#ubx~Bc~pYzW=UL;Im35@i;zO@Lcdy0dpK#5 z?ZkfKiT+HTCp*osonp2D5pdN;hxmpqPbd#eI^-Xi4;2ym0fkWAtI`zejhz%$#S%9g zYhNMz>DNVTVd`$B^c60KD2@Ec?owl|o~e|gmdm7OAMcWJnwTfM@eU6`&G!-Q!$J+U zrB;3KlU>uhR!-H)aRo2FHxic93OEcl!?$Mj-SX46F@36hKZ?&Pu18~h@;?buLp&(U z`(4y^9yB2{t|v`al<+JK&FoaFc~v4-x-l27I~VNIpN;jn06Vi!E-a0E0FZ2^c)S2r^PQj zYqpL{aDO-d5<0f4j-4#rrTixJ(8HOj>(`vdpK%}jVnt7OaDk6L-f;4sUds{-{oP*c zhmY^qHS?2|3yU{fey%n48wdn0X)^q8^JnCe1KPNoS6tW9(XwJ>&U-@a0khatau}`n z><(wErlAH=qQxg_*FZ0sFkl_*@khE)@{ zyhZohp1M{smrh$mE77@VWaQc@8iBb0#hGHf)9{|{E|kqpwR@I)C3sVaGK?o(o7<o@# zj%m9eN#QORadeaWnZ{W;;)RA)!p3ON&Q+%4scYyb!DH=AvZ>%zd4}*}$np*(x#1M$ z$A>V^;G<6^o4l9!W8S}Xxi_pzG^u|Tds2~f_{r`_b%NUBVOrWCy&B;?yKu=&rK*!F zd@VVz^ZZSHxKw%cj;oyxGxL%zE#CZI~az2WOA{ z>&X`vXFu}umg=Ty-MlfSIM(K|Ff&8l8Fh(Xw%^I)>s)T#NZ>*n*N{#j^`1xYDj)0T zbev@W{M%lAW%hyy^wsg3xm>}soVJ!(wOiF6qhEPtCH1Mhi4MtT`Sq4 zSp^b%`n9d&XR#I3&p+zMQw2%fxbfND3fBkZ*|e-6Xl}pejjGPW#b(Z%3imIQ<2MDX z-nAnpc#s;fTGO%|ohQBGS1(SlUzgET_h-xhe(fM*t7LBn+nPzR=0~Z7ognoUNn&iB z=xHJBvy-W8Uqoh3Lk~KR3*z$fud#_Jt$#V^4l(6c8?Bgfr4h(9C)xnJ5{``DIZoR& zpl9RZW7vaV2JHBmqQY-OpDJvzO+Mxbao$(Bb%r>gwt6})PhcxS$PT!@g zC4Q7(mq&TnM<-^sNIwP-iu_;?TW74lmmS^!57GNS$N?yfxW={g&FI*v%F=G5go<&l z`)Uo+NcQ%o3F6Ac_@dg9PrQqO z&0ag`cIA`@g@@oLtE%KIx9rsvV1ZR-uf1lYkcyT*>F2s~x#Br3*(a1ud0$$`*c?2K zbgTMd6dUi6|5e!i&vik~U`s2~5qP;v)}LDv>sLz0A?X2>ZZ4>;~T%vO`c7m)u_Y=EPfR%R!p&&`)z;JK3`M zBq~4h$wfPHQ(mTif8L$E2rhci@Ac<|?|T%=*0c$Xv)wb;HZ4BgNL%03NF+F#He&?U z6H(QD+#b)Xw~b22FDvB2INV6v-)ax0R(_U!dKvzSUEQ51JnW!|7@J)j5r?&15)%#- zDQx%Qa#Q3c^B07l8<8w|BV1&#lio{_cL}>4rgKxQ+dB27!(cvyYvq*JSF0g--NtfO z#rwe6e&?>=Due<%ieM=;E=4QhVo-M%J#4Z}*vdz)*6$3|eHM+x84Np%y!!kjL^IZh zh@%B2ac*+YpZKjm0hNIBT>jNMT5d`fbND&?~cW}A>W@Uw|`;sO0yh2Iu zcFw~=-M_E(Usu90IT7^VmjkFWZT3#5*6PTrZ_6G!IAkfSj8$JjW9*Ssos2eEGkRW0 zrgF1(K^UhEqw~^qFjN!t1Rdk*q(Q+^bePkdi%i-a(uG_zxv2QEnb)B|=FcCU>bAG9 zVZXz!Fa7&rnIG7w*Fr)C4jP)}YWui9c5JHft((t}FMixI*tL&-xk$m0Lq>4BiDtao zBA3qtkxTzWM2LMoh&#Oaz}Y=+DwW3WSVY(rvqae=n4j!lPBkPd{86PZH?GqBQd;@w zw;XHQX$=TxYPGL~m13$`7<%y&`jCz0*Pw>?uOPC=W=~6GnQ*5Y6w6C^!wzF{N&TrJk$!R7AsP?CRe^e*!suVkDMaT#OrEBJnt>{ zx{lCdVf8Z1-im5^G;y~woXve#@X6|khK%l0`4WwS5XPvC-Bnc3~u6MZN@#RdDz zpD9!2-={o;d3<_8QSH}vmUY5fRQCoP?&5u#!p1fe<@AGpcWa*w2}ep!os^C6xROTp zIGZ7_Xpsvq>hx}^E%Q&;pIV_ZR3~a(x#;(W36ZVJienkyz53G~O2qpw-w^mA*Ziv+ z{AW4AdLM!r4(~$sCD?49h01|~!d-FHJO=<0#I5>s7e% z%6a&fA)3+5e6_7^LF$`6Tr1{#ZL6ZNzB?Hqj+4qbn7i-iZnc&Vb`K{LYR63%ixye} zA8f_1q*og-M3OjzUoI9$B3}`lUj!YTd|G24JB7Telqz=z}0sF6r#@#4q6%8J0^*A&+AG zf7`9)#-I5yG+0yM@fQcr$tkOEtlBcTf7KzzEJi(|$&nO%F8G#|FVZbm+jspAwregx zdt(^JH<>f`{t(T@*;%T!*sirhF1eHE=nFrhS5<2&LhH_HepfpT{}oIp<@w7Cs&=I- zOI%6~gNsq?>~}5|wazpptR8ZCZ;(qpxTYkZ%m3r%&Aj?Pwt^N({eRyMW$i%|e{YZA zHz}aL)L)uG^+x_GLmR$-ZP>4NZsKf+i+EYMko4FsLvt0K98?^4mq)^4wSjoXrRtA2 zpy`;lKW@Gf4ya9;{0;dEDyAZOar4Sfi?nn7Qqvz9+J4;0hHns5P2bRHFG-Qm8L@PQ zhU>0DMTgGIel0f`iz5WS*vdPYF}&_4mhsi&4#v7+kKY_dw=jF(G z6tA;6V#dg@^Oh29KcC>I>gkJCyDKteS>kIb1A#LG#_WR>>h7}LS5Vd=nKU!Cs3i3o zS#u#3hvi#ZzG3?>rQjNQ-luX-C#i5}xD>=r-j@nF4q1KTj{CGvO6V)Rpd<9K%y}*M z#3-ps-hc!Fv8&Z4Pk6o_pMB+b;M&F>zqHQ^KPMz-^0CZO*zeFv=f0`o1~ml%b1@sc z!$paKAi30kT#h_BDbx;+CHW59P^t3Xo}Qj3bz(<&+$R61P@#ve)hKu^ z{a$fCbS8Q3l~D3jv=AL0waYSZ+1LNELTIa>43s8-GiT@Uik_dC@zY=bL)O^ewbDTO zJ)a%A`k3rqY6#n@YyV#Qzc8UB+r1!{9QMNCeWMe%v;BV1suN1t4mSYqsq9U5?VO2q#P2ApW14ACp^~Dmm^t`+D{|{T= z9mw|D_HPxf6{YspR)?B3V#U_eK~*WGW{uXC+QimFgkyAe+nVYKYS>_|k z_kVWM)JaKwB*Z!GVtnU%12ZTsq&so2k>34VSD1EKLC7Tz?kXG3%GnPWo2Eo| z=pt-K3QkavLn{qD&FcAz^kVJ_P$@tl^tr~CTMOLV#YDn59gvijs4k%J6vm4ZL1es+(Sp0 z?m&M0>Ok1N>}t`HNh0y)2_oNJ6*fXT_Nxi2&9qgjQ5wxxITCYUn2*fKwwzUbMqFsK zM&vbzEE_s%57L9s=tb2V44xW_LQ;OJx0l*h+)U*+>Q9*8jkX+!FE8Fa#aA}JTG!=P z_8UpGlLz=W=IBJ~kC8CN=T-EO70^!I2{FrJiZgRZ5rDkQUV&z_-GZm_aDo zsXVZkg&dbg(8>lsQB~LF+4EO%jVAvLH^z~kr3WbI!60DkvNN*^Yki`#CVGu6lJ z8r7m;($Ta0@@j94-gFW!fUU&7D=1v8-3#_aeV7kpyFnyyWVsEw78XehKf~c zt;46ltHyqN11`9YG#r@nBSyKNzHeW)b@eF(&nTYq5*iDG94>BV2gQUue2*g1+<%eZ zb?#UQwBxn{?;I;)X6Cb&Lc=HW&U0fziByQo6is8#j;OiKss{JzMYHWKB}_3E{%Ox% zb!kN=T?*(UTu@(o6T{V#gm>-Ay~1{3Va_8k&;Z7AMX9koi569Fiy6}rR!R|>DP)r= z=Yjot^95m1|Ix0V5r+_9(&r0R#!uxoA7pBrx+Pv?nET4H5dr-vTtBb<#eT#8#?(tk z#w1(MKhSyOr2q~PM618Q6j!=RmmTIw|1*BkozyE1aYE^ujdW&<#q)kB5I}+Z^bjLtK=0Ws|1hT{5ejXle4@Q!aj(PnurjoKI6hutYh_23Lg#)(%4&qG{a7GolXcJK^@6gx;(wx0b;f2R=Ib z*7jSWoM!P+v@SnlyL`Ej`eNv>3=A;ALa6)nxqDBrxS&O80t9wizHU=ePQT>4A*syS z?8Cn$){K-5w~T%V`q1UC<=8^w<~1I%*&kswRgF`9$hKf%M#!d~J4{l+lL&dIU3nav zjKTgijvkTuy829Sc|+K7suHzjxh)4`yF{@4)zL9&$!{VZiu>xUoGskYG12Pba#+fM zMdNLnUm>=Oa&=p4!CL$tRwAy|sb&j0{XcKcq>XWud#5qj)u^q)(&p>H(aP{MNBY^8U2wpVz+LgKWW?U@jOwXtGh$Sl{8$*ha1Ctd!IBqlxn0TtH`% zvR%~1aYRBQUlqdU=;j-AB$oEQcU&p$9Oz*L-x9fvjg8kD7ID38h1RCyU?^E5=Ubgcf-NaZXG8j7hGergmVLy#CFtn!^ADk zS1jHXPI}do@8a@jjecI7y)W8=Y@>tDd1hAOewp+CsQaXYra?7HYpEEn)6TWAZyXn| zgPnu;YHe|~0ZuFI#~tFr?z`CgLAxNjEl_}9i_k1EOvWE<3==pPt~uG}Gab4Q>SeZW z>az%3w0Y&P^k*Ji@KYDmoi!)#B>l7To7oXIoO88ycobHYH8nyVp}mzr080luhR29C z2{Uf-Vy?5js!XTgp_sMNbs1HW5~olO-f#mQuLB`H zZmV;l&pG|nuu&i;fs6FKCd1wibU3Ayy}bu-i?)w?BmA97i+b357QPdFftC9Q!pLIy zZ8y_q^+nZA&7q z8an2psGyKVceUso{>SA{>3!@(dT+rwcBMxDqzJj5Z^D(2y5(KmPow1RH%#Vd21ygGQdZ zPQq}o#PdX<*}lE2CvV&`A)5l$;_?)>Erx{p^GwWUh9RlA_g=AM6@bumTL@o3gEZe4 z8o3Z*xn?6y(KJkF$ru~Z=qz3`kay0dP(?_R-MO^Hwcj<^e-$rGZIn~@O}%Mzdb65U z1wc}ni+T1tK_|(xx2ardXXmC#i-GyG+Ea}|@}jC3lr{h_u$cHe4DmU5nc zd;Y-5W9ON$cet;u3N2-Ox15R{?KYX*9zra3aeBEqV$lerT+*;pp}!t zb^I`m^DDHZ|6xesn!o=>wWqgB_2fqVqzk}0z<%+o3kPG*3m7}qUEh8^Jmk_MpWE~{ z!EgSHQ=;p&*H8-&3Ja9YjjMz$PZK`kdDN}WNvGo+?gZxdRL*aMQhkeOE5v8k@i~md zg`*`3Sy?+9RN{Eqw7*B&mF8}trUElWj?AW2aqc9Zyd@FURF1{2O#C{)+I@+fpO)9? zEb#bEXa2qqbC)eub^fU!_No+0_UiBc#cBY=IF*8}P%57TqWavS4w>Q4$5t!|P{WP- zv_^Q_GfiV#drPJ=eB9*y@u?^xB`o0C?FE`x!gQkEJHHK$XR5RD33xshuBj0gLfE_V!E`nQP+G%HzeHlgYLBlrW1dS7j298K`F(mYD zYTWjt6RfXN#qpFzR+;Z>BoXrS_k~; z08b5|g4!9CY)xK;xbekJANk%_kjW5u4_s3QQ_&8ab?fHAbids;Qd^J0RJoK6t3g*1 zQEb=dK{ypEJ9!g-Qh>r^m zRu0TS%Qh}f^G|N5*g6V&@`Xw;=wAY>xqQYu{;JhkYmchayRda5oWqmM6I4a+QogvN zzLFajm?yoR7`1s@c~L`*V4%vgKd?FN^lM{E{yfr0Hj>I!DU`=D!0?0{@-yxGWlETq z3bHE6LImWYy;1Nb=aUwrA5H(cGbU2iLOQ=Q=KlyHx&lB0$x+=lJOrfd!rz7CVJbO$ zR<+RMqNf_L@0@dBc^_HYES}E>nH9p7ch|dQk3jj$k}RR3PG_SVzaOphR)Apy_o7my zpLt$Wa4s1%n8`B)Wyr}q3lN(7uOU>PR2 zR!!9|kLY2TF@U?inAclIJ_F>Qf@PWLq014I(C|p z_RRw~F&A`a#Ft+`xd+aH*95;&_7mZ?v+9yulIy`xG7q^v4ZX31s1it)pd6_xFVL>EzDHJU)WTg`f?lH_@LQ;fC^whXd_$T)D($V9z_2! zOc5%hd5=-kv~4UtI@m{yaa5_~^<_duRP7C@HJNAezCmF>G0)^&6M6q;beU%4ulT1^pcC_r4rSiN0Hs4Us?k+PGb0relPxM$Fb)9z>-^^A7J~U zuEeo(&W)-De_|3NZ&16@x$%zV6c{-C9>2+*?Rq;Gmzs29N~)p^dXdL__>^%a=x{K7 zZITp}u{qq@x#L&n8PYRI3;;eeOi)VHXLgn)eYI|gFA6OUCpGF#ea&4R)tDZFjWucpx*Vfc4OSQfK}oMu-1d7CL3MJR!5xviWQ8Rpa4!#4AV9W|(P zLySNd*x|cYSo3$G?)Rg9MfO`@_{}r`rU&3H&(Y*>ThA?7suuoT;X~{=RUN)L#UD-D zkfE82<%{iv#aP`0xr({6x8K3~ax3Svi`fY&v5J|_O5&PkcDH=Q$YFnsAO(meC_kA166ovspV3S+Pcke@tJzbBu913=p$>r@qhA z?J9O#qkgwExwQ)Hnro9@+HkwobK~g19;e=Tc9`mlZD7y`^6(eIa6vJp(-x}(LkcYDHp&WBl5Gho3O%C z1=kOT$AffxCHtMums)iX>3Mltr;~q|0UfbU!D<#1XYw&8xw9+=NVt@(W(Y%~9IOsI+@S8g2NYZ@v2 zR~rT%^KUncHSl=jIEVg+W%#S1-|i&~6$ls0?;Wh2e&#UjRRFs;*x~Dw2sW}XkC{XG&RyM8yu33_u5OnKR1Rvd*{DX9X*>> z>U~_YSrKBpvleMcT)}RYoA-8krv<*O$pSJjaHb%aOz^>C&DeK2t{v=oZrXLX#^+lF zgKsd*oy#d#gBW z6*9aAZ?@N+bVa?!q5Z`Awic7~*qrx1^A53Dml|=$&lYK)mQ_D3yp}iv)_B?4!*?j; zvLn+Fuwj$Rv+>L3U4|P3ZUAAON3augU`_}S6SpNqvyW}VaVCR2+$|qrk5k&6Qb2u> z0$S@}A^ocH@5lIeEzYL_#KUqLzR;(@k>$5?4*vN2?}7J8hM<#YF8ttftx^sCO6YJk zJqsTdeSPV+(qwT}@vOq7ZW-x9nunu0H~1y5`-_kJSf~{jK9o$60jSv1d7Lmmh&?{61zITBv1RdD&_nDOZWjSE?r0n&9~lp6qB za-YATTADYbz^ge`WImt1C&uAbPKKCp)=IAITf>KkACKT?J=Kx43X6}X+6rN*P>B#+WnFm zwmeeh+%g1%)+yL5zO>`lkm9wz&Rn{!jKZUy)XKcIT71JnQ#6Y_l@=+dc1KjLaK}r5 zziF#1%*!LCFvq70K-khh1{8mD0l(+u-waI=+LkMO6_Y=*qrLPASjUn$; z#dlv{S12@y7qL|17en4;d2|kQLiK9g><_;L%%jT!j&HzAB7_grHP)n$Qa&-BXgtgUJW?RghLiO87oLbs3u4H=Fk-Cg$JleDr%T&iv4$ zlb_;?XUMA9rOnzMT%Pg+x0pLgN4o>1w$TT`qTT_*wA{>#scxJsX2_ule*_hY@TB~@ zGJaHJzuch+;GC_M@@H9-nTqP}=7#geeN4$y4+GU;?*#U6LDoeK=AM!5W22o(T=mCu zdTXe?2emK?xT1q3xEZL78dzNR{FL)nz4z_`_OUPTSP2h@yb#;bU?+QR0G+^5OQTmv z^QrP?<&_xIi_MQk=aA+R21h>jFycCYzw3kN(+rQ9r3EV%K7zX+GhurvLCxexyx15U zWV$aWUd<;Q5*;|MH*NzNO*r}#mM4KG3pjw%>8vgCp@M8ErMm<@KY2HW&A^9*r?zrK zy}n##I+*BvYkbxN+y3YdPQ3YV+HFvn{!Z|iFq69TU&GH~3VF^yO0M0X8c?inX? z%*RZ7&b+vt?X|f%K$6dDS_gXh7;Pcwp>oqJ9fw?59nG4u7Rfq;sLB1Q$ZM{FfhmGL zl+hnOP)Kq15y47sV_AoRfaoh5p6sn}9~s{T%veWN(0P7e>dHeVmAdJZ{WQ#EU2fAj zpHj6ls!Z5ldV^qVIpz4P5*ihXpu<4C`)-Aa{j7??*hjUvSmY9$BMT6Xj5Cgt=Dt&D zo2E+4CeC<2+Vg;7^NEesH(6LN;k(EDoU4;*D&1D`$sgtm%=l)zU)N>#m)X|`h&j21 zaY0)uNZHY?QXlLi`KEc_Gj(>;-L^N6@Sw}6MsS1N@Hp8yjy({bDrjWR^PrM#l7k>U&T3nmeP z7vpjeCO$j|J(ErW!a>`z+e0n?{rvuJgx0DgE#SbPAmi8myZoNZ0ma98*b&M(Klxxw19P`JC8)khV(d_}>0*6UTwx^sL z!%Om&fl(pq_Y0uX2A^9svxy8tMv7Znh$SsI?b-ggqbx+H-aeYOYk{Spr#v~wPf2e^ zcw1&Wi3pa%(I|ELv^q4u>auPO82I3|Qt(zNz0b?#E0mnOpyUdRT!Ek&R^d0bPh#>n-~awR)duJGY7;{|u~)R%cN$|)H2YO)Fk^MVzf8a=5r)Tf@>5Q=PB`ouLE%fBZd zbQ9rD-_QZew23>;GGW0;|FM@*Y1qr#Es4wdbn(ges9ZI}po6~wX`85}>aBmVocU2_ zT{DU=@+tFQOjgJVfKC7NN3ys5PHbQH#)hv!7%&&!H&yKL7!-(T>XTWqs`U4^bCe;- z*Y-bpAZ8hryG%be-=`>QX0dw5F26Djto_v3fDRe12xFo+ z9;V(5D8d2upk*X?)g&}l|U`+vR~B#)b`9cJZOQO9cO~?#RNGtDmR{HnmMsMJ<$Dej+yU( zcUSQT=zb5o=*XughdA79;1m54)7Z{YO;gd+xx)c7g}xWV}z^%I})3ZbC7a2)|^BFYMkk zL7wRIsA&1y666ze=aZ9zAR2F$uTL(n*^(w${x2CxF8Q!qbv&2K$9By@?_Y@Y_v1bM zlbq_YSgtz#CBQ*mf`0l!EbWGHJ2aUxC<=EyE`b;=$Yja4DQC>?GHg?IAg@!*wEg~} zu@XWJvZUu^0#{rf{+s|W;y~+&x1gFni4$W`N-G5=@x&`Labvq7K~Ao1H;y|F^p*@L-`f(`+wy_y4^ZBhgi1;5CtOZDKEv!fi-Y{90`>duYOX z?^wPzeN2OfeW~uD*b9SjouhK`Ml3ons8;ZB$sh4+b*4{8XOJ~6=LXG_SDLHgK7GF= zgr+)-u6`-WnCa0_`T4q0p9dCPCAXbGjEQsBlG)nhivP?D=N?qd51J#Hht*&vIe!d6 zB}LfxR^u8rCCe#}e#|b$S#zz+aqD0p{tnp0#KZ}N&3Nm38=FWD@l+Isb2GwgyNmP` zFf|PFQ67;%FwDf)j}AZ|z9A(b_a8`_k2-t8D8A^4zuztJzyIx}w4b)b8<5<=1NFnTOuC(vdmVZ|3%NYReE%0COD^(?Nj()rpMr^ zE6u`8_#E*sE)VXwoqXN|8n4~l7X<%u=v^wi>&)iwAO^qWc6>KRt-J$ZpYUuSdmlY43mKI!8-L_$4mZpQ8e z20T3nCX*qXlB~@VPo#tDM;*m)pj4*F>ON=_V!9wTuCmj2`(KrCdeI#8e9iLhn{{SV zX_~I3ca>(JYf$FtC0EbJbV8*IKW@GNIEH*>>iwAzLglspg-`oo)Q1+qa-_Uu-TyUE z|Am19+sG(NSnakQ3?pm}kDG`$>RRW&9&l<1nC`YepF%wHj_0yWop)|AE!7jhX0HLJr#c*A4RL0YFOJ27qy^ur)9TDr z=*sJ8o(C)eI9gRz(eBRCySxjk`l2{SBL3Xl(oFDF4U68b1z)!KrOuV;F|4m08JY|c z?9&E6ek{7BCBOz^>S5?vw#;kFzozFf`(>RONNNP7(`~6MLOnu?8}L_j&!8%Bregm& z(?JT#3;%h>VP~Ut7?akm{~?pBvdKNg|Czt@p8#}IxN6B~NcbJv!M=b6T|?r80mm8HwMKS&$_#jg8Qa4fKQMYRI9lV`nR&*&j~c7nA&& z$~h6_hc-*#5N%8tHdFJOuH$2NSa4oc@{ZNbSm&z{FlgEtTs=_#Hfc!9dr_DVH80e3 zg(WDBRgkBkW07xBnw=archhm5jtG-!_PGZmT^x23$9!-WI zcu#*`T4ZZJa7>>7Oeu>71akxvh#+&(;+e%Z&ZU>x25_d{JQUD`A^=K)DSMN|_Y3f_ z#Vb=m-#$!KCYC7tNvCg?fwC!@-ziAWBC}}Yu|U2?>@$+^Q@oK(GR()4G;1F!h$EEL zwWE2rE1ZX*)UqvGl9XfXEy||`<+O!l>NlwTi7;OSF)Zm|G>U;cw79sVuwatj;JC)- zot)~dp?};jX8y?z&kS}+bKz5o+Hqgwl51)9YE1i5YzL|1G3St!E4C+E{zzY?9U;it zfCtLiCkT4Xpba}Us9q9ar^S6%b9z^ZjkLJwARC$cRn--x>extujgGF(ehRPFh`l`Gr?Zyc$ zd291&lmQYRtypMxw|w<_vrTscvoPDb@KlEg>z2F!?4Rf7t$CvGPCy`mlKq)`as;eD z&QHzf3`jY<4Rk1Z34FC0r(vER1jN1#+p>dTHXg>`TmTb=yvK4Jvi!`(rVbAGENa3k z(^~|$F>Xx*miE~jJU`MObU%g>vd(n~Q29si0>|5bP90mkGA{E)&U+cWk`iMvu&X25 zF&Z<-qEH_k$IQGC^D^=FJUy6c(b;<{C#bbzk*-qWc=tSIfPJ>|g-!LF z?gOJmKlLE_#jQvGkWlaFuVwy?n|7}O6!G3_n7--1jyNNS{*&Y1ojZepULaTfH7hZz zvP)(CvrW4}RYSDSZWq?cspLe0jk;u8(|9Wr;XoFu9)s)|={c|cykaYo>$lshw9GPrXBbOZ*&R42UUGRux5a6l zrj5DQ&vGbsKHMSNzWzFs;CW%Wkq`-48658)<2Kx$Hek;wBj^?v~6x2k_825)dv}lXD^EIZuZqXw%nt<<(;9os%Yy_)HLv zkDO2gFyw%R;)%T{XZu{%RyG0DZW4w8X#9o>#`(#%a39QFm8GWV&wcq=;x;+}0H(NI zMUQb7ZQ<_u$*8;QG#DitK@^zK^C&-hq0N1!*N*>{M)1!TgCZgv$kG9#_Rd zET>MeFB#cA1<7uHZ70#DQOyoXo!SvVOAlxLU$AmtSes+vBcCnu+)G#qwkO1GyQdLb zanD9$poZ(spV;MTi9PI-9f5gKu=&=cBY*0tas%PL+{vo3S|{LcV4snuxt*tZy>VEn zrh!MwHmBX*uJ=#Ir*-$fo9B*m<}CPq)D3k=*mSB=2B8#3eyC)xhaK)xWYGQV#Q%cL zJb?7uv}ooq>Po+=^uPZR;7uO(sEXcEzY|_6k>9Q@On&iz`~9b;ggZCTW4fCO7g7|< zp0F7#O-2h7)ZpCQ_S&mmJqm>LV+jPn$O6^QljGJCZSKM)jxC9J(=QmzeU?0#TW__; zeq6fh9;NDP2wl{fePf5lij@U!f2@5ZkfC-b%;)O;aAU>*pkhOZo{Dr`8CdFIZ%-t4 z^GKRam2ajeMAyp?pV?8|SAf$$xaCJ1KbKq;xuE)zMdQ{287atXPJF#j=d|>%iVg`} zy$6OwyW@$6Q`J!tn|ti7?U9UPZadYhHWobwtuI{m|SM-W@=(0ff;c>$hx<_$bVdM zyfF#Mu=yh;w8{8d7#X#oo*BP6Iee?x4XCs-MFk1R<5-JTWL%>j?9s`p9uTw^Kw z(ZxK$JSF&2a-uw^i1Ft4Z_K|scT8q+iYCTy)iQW5TPqNVX)~jsIj}$D&cK-YG+ftl$!JpzY@|% z-=qJRpQQ7w-*+

&|A%>eb zBmUC$G0wCs|3X&2UnIuyp2dWQsrcnz>hYO<{NnmpKagW7Jl~5&Hg4qw%@yKidh;#g zxfTLrU0j?(dcrcvuYd7~hD9qQsi~R5o-@ndis}!1C-zgfeUo+&vqUO}wb#Q59|qDA zz&Dx54~`v^k54NmKcziv^zQa*+GY8QgujqhkmT~84m~i6272V52iK=N$FsN=vQw^o ze_>(asVBRlwUu&h0#`pwX8uX2kr?qxJ^o-o*P*_ne-^ubDI3L@Ip4JS{=IuEevS2< zstS9$vDE4$&c4HkTP)bp_oq{X4=%57y#fUC!R(&+Cjw%y4`v=G`hI$DRA&iw_R*{d zPxzK9zZ*DcZM`|`rp+`8#@8gwiMc&J_1CF*c8uvFw`yC*a5Fh~s{xkq5_BUJV@H;MJvN z$kJwB-Q`kUKF!B?LeJB$cu0yYh!m8GlB9Y3L;fG=d5kv#(pt|e`_;sL^}Kx%7c~8v z?Qo?Du&K$sS~%6K;?ApOqlP*iNlsqLoTAl#4WWVqEs`{ZfmvQ2%XJiHy=TqSOC6rt zmc{KSAzt}WP3nX7;gOfFMSn=MCgff@y#;!U5+5c^&2l{(ZkG!U(y(9+X5_O>CTn`` z54fLy=nyw*&!j$&34-Rv4NlM&| zOx0g$J2s%729o zrykLR)JpP?hCFsn=^oIdh#KbAUkU*|^5;Qg-n-o^U_oOGrAFI*lk#jYfFH^l7l=KY!{+e?m0)Nz zcY6qBrZ*P@xWqJZ)0a^WfI>@XWeMt0?2q@*^d%C_>}ul~(!Wp2{9&t7UHLgU)rAa- zTTuA-<>I*m7G)37;bWcS!)*Tb#q-sD^?zeX)8|3p9{*iQdCM~p*hfZ+hV=w+pFL*{}FNL34uM&X;sru^P|dI<)_tkZ^y;zzY5JvbQkH zQIkY-smu?nLkkPaW6p6U=yhU7)0>IQLAQ_%yS-Y@*6z#Y?~)3BJ+!#Q+mja$S5Rft z653Fyf9IX;YwdI_)Kr)}Kz5Dr@VU4<5v1VaPhN6Q5SLvPM4eWcdLMzd;}hUGYi-8p ze76?nP5|H?K)m9a(3>xuXuc&o!YQD9feXis_I66~uYB1DL}@QuUA|U&onh!`zR{P> zLvi5yFku1)wSB^Lkg*2yUel-6%>JXWbT+}T$X(&!w%=jNk(Z!I7!d!iU+?QTm+Z)_ z)0MZQ74eBCQ_?&ONqwW6YC5jO7jqDRx5}$^(5xAD6WWc|cza~fNYg8Nq*WwqlwJIX z)Op_2V5;TkSxppAN6xQ)apzIMmj@-Va^V6!e06^H!ypd3d?|TGzj4noYp(63VcG>jidlB~;`AIplUZ+PUboUznsy ziI2>(#Xik{g1g74tb7kTi5MUeAqt+{qDz;#pshDEcvu`_fNV{WcVO-~&nB>;XX;ieZldP;)ECXw@==?au{8|TuJ2E2%8|pz)7Re6m zk;YhUgr?6o9)ce9Kclu*;9%k~NP!d0TX6W^igv4Y>UGPmkUw_dEy=GJF0tHD?B@(0 z{(*dX01Va&iFor?yQD#&8f!4VmL0^Aopi;2$gA%dp*XU>Q$#nSg<@uILSo?Nh_AW; z2N`)8kOylw;Q=iQ6|i|Gxp@yt9=$H|-L?-}vMZh0VtRccZ-TYkeT;4#Jms3rQh09a z*d{^+J3_l4>xw{x6pq&A9rdK$6h5C2h$!bMUfsajx(JKeeTgva;rQd^aJ*?;zoF*wuOy{)0}G=ZJ6t+uia zu|p6p94`a9eQVY4*wjIchGglDgjZfXhJei>s_|19Rx}1cO0qXwy%1}sE(BE(hrn|` z2X-i=HhC?~QsJi=WmeE<3qJ|qQ_w69dJg4;__H_fI8fVSFE!@@phRmJM3qIC{9Z7O zRMD8#NCu?bBGE`d6F85=;Hcq6fwAZob~9>|RkmpVWBChVJ@gTjuZFzUl?vEmw3@@Q zK9hx7EkJN{aFsb%9?F*~j2=DHX2_V5>~$D;SNy#z>uO3iXoV}hx)db# zrn1Ofe%{hV4|so#K&_=d?LBe#iZ`ZXbkf#0dew99o8V%vcf^k$kOEq=Vq)d{YmgBa z0}~02oiuat*Ss;+j1xKl9tqP|nlTXboid8!fKfM7#9j6=9!4D4X56Mx1DqdrF=c}R zEUNyCMVdpq^2CUU?i5e{JTWl3Rp|-rS;$sCpU-Jh%VQQbv){ya94>>(JSo7JArLhv}nG6;DrcPYj)@ z39LN3R(jz+L0X@;2ZI8FqUAjI-P!YsRU9zUjt&;s;>ZC%0`#c7+{KCOpxhYq_bs{FY$%~Qh0S8uW9Wv8!l zUao-G=F$`0;g6HE~FrCGD3%fdBAkZCoe+EN5-1W1QI=$c6HF$ z=~mvx5@Z7&sTOe0;$_QZxBbJ>y;slAJ~e1bL++w*7C`>@p52g_m*=Ucv$klK4z+Jr zc5sPb1+7Lmh60mO=ZJ$Y*~9&LKaeNlcN`^%7$lu1)LNZ6u`O!Gld@4E8Mj=^hdkhq zB$eRw=O%h(Kx*Qwq#*gUWCP7)LyHSj+n5EfQ21}|AnNla#Fzf|ot?Ji@cXhwUn3zc|CaqxEuMQ5}c7%0FA z8$*WO1ClQmg7Tt-!^+FCrd^rmcL_SRaGx0g;}xz*QwhbPrI^TaQNtEKt>j;}QnMJC zr#PGHDD3pC>papMw{OQmheL?p%D>Zn2M`B#k#=dv9J^v(h5tLfAo-%gCyX3sx;MpUunC+LvA)HO!4TUWs z0xY(&Y6vNa-EKqE{||d_9?fOj_K#kMB&DQLWQvkTjU+=7rO8|>q>_kCO_CuI$rz%9 zLX@$AgbYcEija(vDIt_0Gr!N#{XFmUzH9Hb_F8-GfA(6h^{l7v`>yYBUFUV4$MKnt z7|&19R*iY2(5A2iDC;OT1gF1*_Cp>??T7cuw`kd*G}fs-4I{+xrHx8_Ncd?bRW;44 zmzI@14ksk3t8j+coR#DgcDUpViRLjp(3n0%eynM@4Ax5>eqd_VoBdc61D*ASGZno+-H110BtM ze0R!b$65(08mU{k?CCL7`;gZ)15W{y`%Y$${+2LqwVV%l1`z zfBV!2UDSACZlmYqXwD>g?QK$Dy8{2i5+}pVUD*pWQ3AUFRfZ9kQoM1d#(k)zO;Ywa z3y-QdL_=!v>(WqU5EzLu4=d&Qt7~c~pZu>RR9gvTFi9ZM)RK~cQrz*tqaxS4r9c9B z6pc9hXeiTB$oU^4#+51=Z0$=BVuMm;Uc^Ueq1Ga{4!EMq{o{xaOMMKFH3i226Pq8g zs_53L73VkkDacZ82*G5)UOX49@xj_w67_G<`nQ2-8@8eV#Avi#M>5wymWacJ`VRz@$LZ)JFghE;7-_6* ziRX4CBt>$88QU#8yX1|-xJqy>u~1P?>j*jk^02^BcNCP)p^A+dvUrtHL|iu-Wq=%Z za+{DnLh>pfs{Xgw?I<6FZmo;(y?=7CobxtR0pq4{Q)HF6K>J6jDF1RuAR|0-E$=ef z6pJ44SF=e2@voEx(ljBS0wa2bvDpJNhs|^b2`KP*LY3=+tm{13ozMwsWh9L)LKssn z4n-6}lza4haQ9&-3JP#=K0z)$ev7-(t&DricK5YKl zTqmi?qS6DW3eCMk&2f?hg&7mZf7rC$SIt)=jK#_ zEB(*0^?W~ytJAb|2;~off0DAbotl3hHlb3%_Rdc1_eG*8azur*r1DW}((+aV|c+u{3yFL~qE`^DtY&GPD?|k`gH6 zURjxjySUCYp)I}a+FU3DfhGj)-h_PF6Vg)T#Qj4G+7yIXs*iL(%{C}2uGXju$@P-q*e zx^Q&0ZR^P9?&X^aiBo=?y8vxpwBrL6EiPSWj7pFK#~>b_b@*_tfUKo%Rnb1>h7CHM z_X%-?pjZ_pF`u~@5;k_kIl8x&8@I1Z+zq7|g28s1qv@*r@}bZY2iO&;v$5&*$u3v! z=Pr8G8!{1a_in=65R4mti;rP+>eRb^_J$%TPp#uBKTNi->ny`gLihmoz3<1&29Pr8 z2^yNd%gFbu-dIOJkn+CeLU+RX5eI$Llrxdf2mzy(&t2dCsE!M(0AvzpH=@6L?_Qm_ z?GTe$-M(cJp}CZ2hkE2Ue0)@6seXz1h;gRVn$CvaL6THw%9+YZ zR`kPe;@53#SH1d){%mfg%h(qbR+5*O)~vR8a2+QNzV!O)f5nUoSwvX@|e-zrBK}Xc<47&wpM3zl+|Cd-C5u$xh4v{UHB)MgDh<{P)M^ zf1jHF-ADfawX+5$Bl<_!e<#dH4P*O?7F-#;nyrnx4}Ppyp!-or872UD1viiXh53F( zv=9Vwkah?5P19N;EX)8{hn>~#{Kte7ZGsbPcJ1A}11d`>M0|hz_%XI`CY>%0SBK2^ zDJvth%irVaj)`B7I-Sr&z3YkAnyL3s_z;2&Z#L{IJOHT!K;I2y_Cy1oTUE6XR3EjJ zjzf)8QRnTlIGS5st=BGl?9?ev2uCO}2-uf09LRZ6d@G!wR9jTObk@oC!y1<{9triu zR9OC@s?|@n{gg(Z7eiOjLMu8hYmHAmUN95iYHX~py#?rt%{f`}$HmSs zt0*ZL?UJ_a>zk+;(I^L}O?3kBsb<5~*in{iToFIrZPv_qPZhYwF* z2en|d(&O0-sANh+2HHoNo21T)Xh7+Su5vv95O#)OboOoUN(cRd5GW%8y(5Y5?u(2r zpKV-;>rFwzpwwPfRY@Oi19r85q~eI16B8G2Rp+82Z6XcbOcSsL;HI!*(a;N|7V-{x{vo4Ok+DQNW3p_vPJna3RzO4^VsI#*j7FtE z0<5NGPa%$2>SPxKfa?NW@io+MP~Ga-b%UJ5ljW`Ddh%tqBxF{nFgO9^ijk@5G?JQQ zmR+vX=|852?AOQ+(r~w&@jv{)U+8tvH5h>e1qGK_bo`l`Z``Pe;Q~y%4rZKz2d+ri z_OR=n?ywym7FT?3g2d2Xt6pK9LI4M4JRoSS;-U}mzGLtFL?jMBUkS+@-uOBim}6MD zlKhS9lbkpyKU)>u>roSvsnD<_etANVATH|^VTVI{7BP3MZEas;TwX}4<82HOo6G#d zA^->6dHM2r^RF567>Jsu2Gu%TzWj9TWkqYUSHX8$WDA`8 z72Gy7fht}0;Ax?y1C+V(=$SJ!|Ai=HiE7z<9f5R!;OfK89HjY>$ zpz73ufHt4;Zy5hlKjpks2nGYB0#%a-3=w=k%+IQLL2IaNRwh64M5=PESIrCSsIr=< z+wF(`oHPczN1yEeeGz1ZJbJt;vlR||VQdI&_&7)oef9P0V&wPe@VCrD3^1S=nk4|7 z59B>IUp}bK@&So`7K}Z!rGhRq?v{jp#s)BQn4;_h3NaFVGmVl&MnPb)4oF*VL?WXx zL`2*{-85h7BD62Uv&P&5+6~zJ)5r@zT)^0NjLG&rv;KbRfqY4S^`pUu%@BD4Xk}qA zyD>hn8wU|nRPw%~Y#yEf;0v!2zl*gyBfy+#*nnA? zMl{PZjtb-${p|=gQ{$oL15Y58Qy3*E9b7^CNi2?EEqAAzEbrf1Tg77_lT$;eMDS1p zHuKN*nO+pO&+tI#DM45*BMkgN6OL{kgnh9ykO9xb_WhuD2zOSZE)-&mX-GhH!N$xe zq7U{VkGjcTMW< zJv|VZ1lC{8DA~Ku(f8bVwG(t|7bmfLb-I_7Jj2Ks4g2slt8#on6>Lp|MNl1+vxzR- znuJXt55g#&>|ijnHwK>qfp&3scRy-sx*XmFBWzJ-Js?0Mrr>n9wuc?x4I%(8FbWX( zB+_J+-}d)9s(@nNC_rKv7k0imfDfiHE?||(Rr8=33LH0R@U9o6zj(n3eH{;$?-Wcq zrNAD&Tei3D?Chq&iNn)l`L^dJu4P{KNMIIU-?@``GUj2!Wru!=lpkERPY=(er9+A~ znCZkWUc5*YS<`dez~v9#z#0vWDKEnb`#nueC0F8df!MesK%If+Tt)8SN@3Idx`Y&kjdti(3FHZ}$-Ypm4)pZ5)FpXh{coW>y z{?5VO%F1@`zz~Y&WpD^U3MP+TFVu&#w>h8+~_ZUVcT&* zB=E(AIOpRyx7Zo-+qeSA@=T6>@L<~C^c*W})|%315ITZ84v(X>O^*F4Miq}F#|J!%7R6OOx#MEP*FL%@BJ7gyQa;&;qZ2OF^=<830xP}Mf9$Kz~b@uYjUukg2H=XGKp=B z6Ov08oP1!)QXF%z%Wnoq7J20U>D<}QRuDEl=mNdM!+#k|V&3yYAuG7mgIG0e5GjhU zOOnHM{d!@;7gk0hQc@b1Db=1^Tw+oaV(E2F5rI z+X~mb^LS>-zSCtAE){Vyy95T8CKC!H@izTYh!EDqN551yppgp2y&7jV-aln=9u$i0{zJ-BZf-SKfua#p^4GI$b* zETvhgfqIFl@%OD^7Z`AeS>Re1KiHz3TZ)I(*y=fU*QO_NCs*9`l?1>Ym)8kJ?p)e# zp)C-B|G_w;W&L}+h~$k6yRhYWpO)q=T|fm3gkT3QWr3{3L48G-vnhj%$~NuRLpMN- zW5(V}<|vZuBz+eHoQ~F9j<$}nOvf7(IeO2fIF7V03#y&Uy(q9%=Bcd$wVxm{B(u6( zT3jjijR8kiy0YA_5?c|?MSoBU-Rv*iQdadngWxbw#!P{*C@!lKM}{I=1b?i7?`|JH zFAfcYfr;|0-HN)(=!x|t3fmi|er*L_>%NzN zK#RhGm&9&Np6cU=*4Kh(VZONwIYLnGUF3kLik(4(W8lys)PRu2VvbeJI^V+w5B>-n ze0EdI*b+7^ayY^NPf=Pl*%-af!^ivg1+m2y2{HtfJNFOFTaVEctc-c>^Swx`Ce$A4 zchkMv!^707-rf^aE`#J*mgqK9ozpsmUu%51HcQ9{&o^ayxPs_pzCjd0QZ*P%H87C4 zzq4|GYetP;>ELbos5PnK4PbhZ52e&TfYP5~xaP0)pg;_Z!C$|adt-OAK)#|6r1f=K zrFf5H@;{o5TZ#R1W@{9$`_RMAGn+*b5?7t<)&4uxp$Z~D7|frNgQ)bihDY#>y>Tx= z@$;xI76+pQ8agD!(6FWrr#>nAex6*0V;p-M!$3}?(<+Y=i5?#|-MABNjt}jV6-Tma zvg|jL8|Hk(m?o-Ta4Ye3ODDZ3J^>U3rp{B4G{Zpufk*$=41WULuHJ2zdt0tAY`LlSfLKq^c$7Rk)?ZFlMcInP5w)E=8OD2C z*a3dO`Bv>}%u_y2{Wb|!U?b4EHA_Zk=EL__q0wOO1x}JW13SV| zbXnLpSTNDp2djy~naucA&rOaA5iL10o~F$x4Cx6B$A z+#4F-gCS`kci)av2fH&BcljuIL0FDSGzygvc@+QZuE1uIZWhT^2)tB|?d!nFQ5t-x zWDTl-qRJa{m$Nct?S0#IyR*F4{v1v!JSO{FpSTD!JDvWF_qPXJTn!@83H(Lrf^GU& zrTye>IEQ-fYt|lb<<5s9XpxZOVgUgwM9T*DvXES4OcSmuGt2K_gCmVAxacgStA6+DBz<|9g#GRgbel7 zHg{Nk7#g!k=pC*n2)4} zQr(<>L-c_tU>?KUL9Veu#|oWe+F1wE@0xu> zgPvM5D!M-nk$jkL&B8!!H{bVt*X)aNkS0w*Q$|!T4x-+%5ulN@u`*=8Uo;qGSu&Px z@&M9JW<%SfgkE68>^PXuWhAXW`?05Y_9HkZTw(kylL9U*8Cw{%|1=*)g8&8yD-%`t zgFVTNpXc*pTzKg9HPZDL4{zMAr1bgENX6Nn4TT-D5n)d^F9ESlPBf(SP%0#zy4bmA zKlIozJ1w2jDa@njVzjA|=z!tUev48(-+^3wIGG}0b-A=barlh835gI?n8GHTja~2)1N94;*qW8-g2CPYqI<&8D1@;^T z+H*F?BUkBE)ngIn8(;#>^b$yuX_P1sv?#2VrcOw&uPKWnyHAuXW@qalp!qD%q?uRq zGLK>Q4yr}9m{>mgsRBKL&Yu@3O6I9Tl2EQAZ?)0T$GAOLGlX~TE8Up$zAygrS^blc&}j`{!dXkM z5$2fV2Z`hIA3p#b#NLWLDUI~{-EUJ)(l`nx7fbu>BL9PwI(q1ksEEimW4*cj{PEAd ziqv!)hJOZT-3A8AO##XPJ4iNy5W?L?T#aL^r>rk znlvT!_hr7qJ||tRO>47Q73P88Zig22tp&^Mq4#3;?p@N&rkWmKnvpj0e&CTFg0l!;wHbLn=x#VHvYPJH*-hQ1hC0%^o=2!c{81J>{oWMm#YIyuoZNih?@Q_(b1 z1To5GR~SSPCvxBF15i2&&WF6-2rVI^WwfipLqa$)g4EhoecveRS3+E9N(nSquV87T zGxfnZo!!vRwGIv7sl5SZ%HMQ|n$&nr#U^Z56n_nsH;vbUG+QsW4;w`kq-cyYOa^gv zv4~nhkHgR?Hz>v=*sqd!AB=>f#&8xa+DzHiPV}$5hWqij0o^C>fc0Dp6%a|Nho<@o zTGc(b2z{{%cpJsnL01~H&KndH1^6*kY;c7wgi&(I8LbCE^vL5&(>(t|!_y@Pq2m4! zCVRk0Ss!%_1)MUucn^$LSaaqEqIJ-u4usd}O@!sxw0ahm-gBNXPL{-e1=l_!ARs`= z`C~BfG-X69)(5k}g+x(n(19d<%3RElg&x~4>rSB^D}pP|#!>tMaLd3&m+>K&elPcG z3xKHqHQXXBKkD!Qm2H%DRXt%N0U~nk400S3AMfR1CvhJpv^WcMTG`mpOqCZ|s2iHL z)bK^5uP^)8zJ2FiM0?NIOn2D46Mf%X;3K}>uQ2Y4d6X0pf-GT;JL2B~{vOP}OcRO2 z{^I$m|94O(!EBinu*xVxZta!ViGP767dU6dMQ=EJI8mqr2tiN*p&jZ-QhBayn=HLOnkBt^Dg>nM2 zwmFz$u!*Sz&Q9{h2(J6k=p7RbXvo3lMl&|J>^=f6YjyBdQ0mLO0C3*P4#jQd^_>Lk<%V<2`tTXF zdA+9)mU{6n`q#hMB+e)oOh6gR031b78mEge5X@=NKN6sc^h|97)$mE2N^g3sa31B4 z9)R;m1TM99X>6as%+Ob)_FN$vV8i=^PQjHCU`Rwz)lSE9oNIW&%8Cc&exQ#K&~-+U za8+3JARixvK!5bdzhI(uxA$OB$W#3H&6^h>Y&-vK4r1MS*guxM*GnxpkXUarRGFGANHluA?QoeuxMHD#W^N46c!am52$`e640bmf$A2M4wy4lq)=6XASh zOpb=Y(O4Xdnv^U5_&V5uP-2@|K$+f^5?aRaLqICaj=_DNfH9~{L(@a$298%@A1+LX zEc!io7>R~Qkg9`&RFzDk@LHpQ0}jd13}p16Tizn|0N&4BoXun|CGS5p)H@I;LFq)e zK=D}y@NE#+M|32%OK58Ifu~WI8djc-6;BFwFq`l{AR1~1q}4Q>3Zo0QwMs$!{|4_( z2{-UjiIB)INq;*WT?BVP)fabn{^nQ!q707ox7b=}4g*cQA}=P=zhI{TW90*Yhp4EF zDEg$;R>z|1fD(0UE)2mmznd(4*ovDTz6U@@1DYjORZXK{lOX*fqkwxHE{w5(qLTi>R?!S|;Qp)HA zwohMz@(-~Ud;@n+lx>9|3gR1h#E-#oO9TuT-E*fDV>uL zGXD7dXbv_qFkswwF+;_4ZbG(k@)k0RkntOLl{CyaMdqt6{BA4*LgPAWN}6>+OeIYb zLa2$bq$C#^8&R!!!@zUBTzeR?zcMm>xGM)Tb&F9mk?#SxG9NYoiW8-Y@A#qajCA8! zSy@HR&i($+{`LkDD3qH44TP*S_LN>ygg2)s55N&oII5|2C#Mqj=;ezQkOxmgZ|UPg z)tPV7?6qVnK)T@a*!rmJmp9#n#55N!D?;VV!(h8FPq(=PLXFh+qyj|&L#|v9&tM<* zskp(Nrsn|B;1S8~pq!0)x4ql2It&`k3gA>4fMI}#;Zu5)C!aU~*(+Iz;ijJl{e+B7 z@k`Qf0|=UhjUL)@s5acSDSQqcgBvX)G-i@k@(kJEX#rin&wU1}Hr>`<=2#alW35!zA;$N-V_3#2-v zE2RgP;F?2{OVf%VqPFR;LkEuB5w7{GkGwnjZv|zm>_8t2?`;2i96Q%ga#3giP9eq_ z^fY;JNDAR6&n;eBiPDln#H-RQg_M+()FYLzkL}|vJpdwqA)@kR9P6=pKq9+{biv3n zwH+Iyk4!&*r85!>koZO17m=ITSg~Yj#uLUlu`!N_z3cC3P>9gl6)lHW(fhO=uSPwh z!SYdd_Q~@7^D@fHqNL)q*@h<{ww3EJsUudN6&pK&OTdAtMW+-r3>?V`0Koe1pFj)V zZ9J}H``EGG{`2DyHltUT<==ni@2~v>B=Y(1KmGr2w5G4e_A#wtXtGEgfNo_tY%W2` zq9cEuf;cAfv6DfZ6wT2|8eG5)H|pZrq`nO)V2!atKVx@33m21dzTX%e^bw^@eu{%B zP9&H%a4`(s)@*btt@O3KVKb+1LF>}*Yl^t-^qhppABP9~47{{_9egL}M{Vxdj`Zup zH*BW&&2PO{7pJf`cw(0S;KeKuZ{dGOqcIG`>RZ#pp99I`Z2EpQkQNt4(t<)m@u*hFL(B{CO{dsIh0|!3C-VcU> zhv`qY=*sy5*#6IjMg#NziOTom+MGFxf3Md+{})(`9T-77J}81XUHlpf3pJ_OqY1vl zcX>z3c8T$V^e2`csCr@W1}oQpFM9fH4&`VdDNTFvVx)iGYA7uAM52_ohaV>|xG=WRb+c9LzeRP*%$51BMA+ z`F-F2<$ck9Fbt}f8Q)4bQ&bQkk`wPc?f$3wa3rQsb^7iJMIic}hf`4-I;*h_nN& z_HMA1p&Xuu4hgc6vXEOS%&7Xn;d2c(K9J5dlAHVsZx6>_J_*6uyXELVU%~G{2oul^ zQ50VvUScnHL&=smfo=K?S{52&i18$SXM!g1Vh6RHDMhTQNRU9}AjC+>i z-;MjuIk8SzNY3-!S9?F+`5aM^JB-R7tWQ8~C3a&+^E2C(vuBs4?OQs>>c^kMKVQuE zyr6O2>hY!>gLlMBvptv0m@zSU=>oO%t!V(BDBqWH#_FY13J3xUU~aR*F~0DG*?XAc zR)NKdIU9~0*>f@+Jb^ALYgn!m4L4Sm2It+=ymG=GV-e zBfx;(;_G3@u`RMcB=DU01RKh7AiH&OKU{Dg$jZt!IE6ThiyUyK_e)tCV?%$2TEgQ zDI^JfCA1?8Ah&vr549GE0!C%!<&lgGzyYO)75B<6}pl)_*4_i+#Ac%+_?WJ z2O6^9o?%(wwWGf@s_FFd`}aL^F3o$wV$vvWnN(vVAzQ#}dNs#axqo9oA^TW%-~5?h z9am)k>DAi08gXysvlNbMq{YqHLK{-+6BNVLFEurgUJcT!CouNcA$$V|n^qNtHX1qU zz+5*IHa5od%1)h_ePN!cD4TGmLdf|tyr-!X(*XMwWSoD7nr2ZI2&U#2a19bOlF^w~ zCZK-kJZMxCMm_NkVp(k_RDwz+n0~2d!8758`$XOjJZ@ZVPKe2alvN!rLHac4o^mM{ z0$eYX3*bZ~pfGPYxC08}44<_$cZw0?K~EsQ zKr^V(PvhS|E!gc>$kX0@HGj^sHImH9B*AH*hTo%t^e?x`a=xdU`kKqA+0lmo3 zgzn&&F6QaE*bKWaJ_Tm+c$3-@o*4Cgd+UGhRh`l4nK*+o>M1JUPD*$<8l}3H1z9{xKG%j zry@U&1`uHPFhcGH=5_gfX={AW^!0zjt3V(`m8LNxOu~ekuJgUv?sApv2i6&WoKt7n&kdjFOY^5=Z8+tlcgZx z&(y!aO9l!z4ShZorsqG?Rc)P7cfe==8hT1mUezRGsawmNvLm5->lNgR{g;524D0nNarwe$arJ<_JH0Uy@|~gxDQg7`cJu4 zl1vpK*25XPtMs0-V~&l+MGM%zFD33k0PJyw`eq{&6L&m%c{pc@g+oTWKStVuMm0T` z6cVzKb~A`Pm@>9&G*1M!ajRe7wlU6NTTpw=JEln5S}+0uKF!L{GH7Q^&Twq?l}5@*(9VmV4ne!E|2Mx z6mUcRHj-DUmYDa)9y>M!h}y-wk5DpAgx9D^-7O7>-jMKle_*1wQV4A+WaoHLb(-yq zho-!ffb$>3VF5&^8w4Y?w7J;+*5oy0rbxe|&47FG)S~RUqd~{`0P3 zo3GKEEuYpUSxkH=91FTQ=K7Xdi7>rKX*O+FCUx|_CW~!TyqTO_)FowR5G)j0kyZS$ zM9r_z?`)Zn)YL!QhZg#H-~)Y0TD<4Z<~gbsOT+qA0buw!3^7+lcjR|5JmYNmre!@*MFJMS+SP&^%q?SzQ8w6On2?# zoE(yRU(&pL+~R|&%br^+8m0{oO2wbyRqmKv*k_7}67$Z=!3AK>!#YC)pB6FxiMj|flLu*b}Y%KfjrmLIaSWdlN z0x-93S@Q%Fn4QAsTY7HtL*>2Xh+n^>-gZ5mC1lf@p%=C}E$Os+G-ug{4RMQpNX?L2 zCMF<%xAl+Z>R&gx^}epb#5s$S#CC&a0`h@)x;3g_jCU{AJX7yl_~6FlyIb{^{|f%| zeecMh8w*5M9R>@hvTND6|G~_ez>L4VQX5X^2FBC@wz4EtJ2t7_ZhhwP#OWCwA&}Js zhHl(LB`7;e1n!fbN z($plJP^JU&>#$J?Ano!OW_sZm!gM;^%)h%AI>#Um{v9eRtPHr|Cr4kg0z%`ZAzO%w zqg@I&DH8Ba(*&3*WeiG2T9S55A<~-OpN((&_KiVR?p6tL4<|D;CNpj~0>TMOSvnli zk0kVN6|#9WE}xlQ?Hq-eXe=v=DWe=EECtQ?COCgWHkk5~^q#ceI1KG> zw+Prm3lap?n*7Nzmo~ST*#^3V3Ha~K7e3>tcFsz%H}is7TIk#rw^j_Ud45nT|MaP< z4~dJ$T!8@17#P=-^ZXGYW_Ic_ob>Vfqy9Rq! zyfo?D7bfh#NLXC@VBfU=%`U#ebM{A!pQhT9z4($H5_luzMI=+#?9$(Of)VUn`C}oZzwScLcv9VC+eQB`U z8yW|X?U$e)ukcyz<>NDd&_7`EaJf|PyVU2SCfOVg;~((6*r@x}bZye#(pGfNBIB&e zSVg>@qkcR2G>qB3OsO)J2vuEM=GwvC+bbnpsJO0>ew&>u` z9b-W&CLio5ICsK!%fVrz305%$r=;_l#y7q{E;@E+u3d)ca|`#MuczkEU$Hc1Ma~zC zVxPdZ17S}XVeuQwPhFSj;aj-#@x#M9t~w^A}im}a9adYeksa^jAtNt zKuS!ZtRoQZtyLcT^$B7H>3Ns;&*@&{b0)7ZwE^qFPljGaSk6S%VQ(Wz|lm{oUPKQ`A3 zs`c_*Q1bMOhM}%SP3~uXyS-Y&&+9Uh@^hIli)B5W_zxb++Zi4^g{22WoPYKB8WIQC7-tny5B z-0kmM8qch4S{+j&cvSNV>yep1RomwjG^Us?NvKWT&a!`h$V0uWdi;7{OMVLQZ!$Xe zTzl!tdi@|Bys^AT)nnAN*#8+ce2769*hWcHOOi@%Aqa z@9=)Z$M3(yi~9+PVbnw9H*AI+nv3=aRH(m~ZjUWRN-yX(7!Uu(Kcrcn;{=Fj`Bge;%hE0E^aQzBPv|C{8$rb83j!C{- zJ!OR?CNhT_b{{o(w$$LyPdC@l&HiU6S>zM--SZS?6{mK_AJa^_D=i)=sacz6cSlgh zLS?w6;Nje703X@RB^RAb1Xwp7y~}S$!9(*GY35%KnVRkTU8S4o$0Ohscp&++L2dN0 zDIqUj#BIHB-lbK~`%gf~?MI8VIGDG^#fwxWTIXl6C0I6yc6f~BdVc;g|5oDRcm??z z&VGJ7cs9z~9cb{+dFslpJUJTvVAH0@r^?tl-oBOioVg%x(tt~3`Lu{Hmd@is51mB6%B+T6lJE-2HuY&n&=m(1TBL?pHjluFeC>c`D#N z&~CeuZK3z#Lz)b`pWs&DH{vn-BW5zNdvVnmKU}Z|4J7biB{XnrPrTy+e|MaSSj4;| zC@*n9Q3|g@@q4Iuc3=kj#Kfp2C~e6-(`s>u(ku*|e>*_D{%(9dmF%(VUmw=lD?r~!=8RMdG8KTnV-PD1Ap6L@mLCJ zF+_NygAaU;R+FrpJAiAF#qRe2;9etue`G8~#-Z>&i%S4{DeVo;J4KX#`QxCwxey%* zB$gZ2i4_2VxP%>wK0|wi=K7$Bl_>(Z0E$yC9}VUGg_@Rar{$izh^Tb#m$7fJU6NFK z(t>;RwYzYZ-JE16K30aa?>5IXNgbmjM<$I1}+$sm7$voznl`$FYh-hKJrAP z{$1fT9`>CxxZg0ERV{^9GG}?poSg3Ushy7z`>?n1MzDwX2qFMkG#{0vxcKya&x_R9 z9#lS))p1ws+RO2VsP3fcuhLR7u`5Uo|i3H!$(S0gYX**N>?`Rvf>? zbMvNuFB%&K2Nxb8-<16kJWlQFR(9Au+GnfFa(|_moYPD_HLhbmi_7iI)EutLOD)@{ zIG=wy^SJwyAjjk%=M|sh7sw|!FnZ=IMW0)`ckO38dhnt8wUlhd4}}5kJ((wV9ejT3 zes4{6`Z7Il4(*NJvBt<(m$6mpI-4cR9k!hL*t+vPtNFAG70GF46&XQqQKzOfR2v;N zePqMEKfUFXdTM!$%FnW``lofToOSEVcDVZ4w-KM3VcbUk==?Z(LOYr>;c%ML=U3Zz za!=phe|?Ubtg!I~*JF-*`DdQUUS{b(f8Ibt+maww)1p_N2}K{)oV#-{a&q33Jrk#o zdF(E#ZqFf_DqEX_gZFHL?#jbjvEQ!&Iim*+HXT>AoU6dlg+Z<3pI}ox_WCEDP*};hamdmDG<+7Cehe`FI%JZ^ zKTEs`uLg~D6;e->h2FrV1Oafq1FJ_#t-Yvs_aS9=`4lyw9by9sQ%1#i9|c7GSYdM3?Aqrn%!jxG z?&UGwel6dAeEAWu?Dxr`WEgPRMwYbyo|7M@FpH7ElvpnTTa$-QWqk`SPV_une&gg` zjbrVa)`y2WDu!NsTLaWOQ}%9}5l!IDU=_h2z-E_?9p(llx>P!G>2xBWJVZn@ce9 zb3CD6@Zekc(T#fq7ly>FF%Y!x7BJve6EZTVs`1ZH4Lk35S);%9ZQn1wWZAOlvD(_b zi?6ayL7TWDC!(W1bhVvw;|qhV*-l-HnB3M++d@u;6py@adNY(F>X2npyX~s_i55=l z+So{wF zPNnsMWYLrhLX2~Moj?qf0>2x#b58ElYkF}?Bw=HJUciHxrt2ma6~j^IR^;oad!{%! zv1$8Ahh4+JR4U zOm0WL;3J37FT^$fCVUZ+L2++o z6F_!|WuUp5XwI5oQ^sTdHx}6l5L`9aCP(h4>(A9zO74))nC-l+dQzosKA$Yd@5&_6 z_ChT&_9>0GFYFGAde=YCz9Z+2&8pB%ZvI(W3$47kXDGi)isS$MY(@8&%0d4 zpLJEvk}B|7^1EJ6tUf(JzHF~(;gf57bqcr&3y0R)`wXRxyG}cH^nj>;uF;)Y;ZvuG ztP*9_u>X_yRB!#kO=f}KviCw0REDFo9d5)9bliw~?p!xs9PP_7-s4#9z0f}j$$XuI zcdZwxY?saKFLl~%DcbmD6&4ETc3Ii5q~Q7dH!7E?c)U`<6Yt;OK3SdF^Gl*AmHXSE z`EGHRn~F-iSk<(avzEnc?%-i(pFLMwUgL?-?}ri^`B77Sn^Cc2vvuDIC;A0%=yiOJsB^OKECYv@Ur@=>&LBqTla6u_nN`J(W2x0 zmzY*N9|4)N?|E5bo)UHGrN9H`fOh zKMmL;BO@Vy{zJQUm+{Z`fPmJc$F1CF>b_E8PF-wpc=7x!Uw)+MkQQf#2Ar8~cOv&B=p68E}`OO!BQ ztrWvUR=4whJC@(J<8Xc2HjfghHIbKj&1$& zc`__1e9e1qu`6@E#Ns^o{)x;xcOaKtAb6#o3#M(p*tmJ+${ESoBW6Lxwmn&cx{~_O zaV}RqHo98d*WA?ftJ)+&$Jo>{ImUj;Fx%^hWi17gJFXk*SgmcEkP}+))@O?HY!2Jw z`WM#p%4AxJ^La~DWiM9LdsN`(v-p0Jp`qZKx~ziZECnu$eb{_lRIXN48Wy-^cC(7@ z>TYi9-jN%}p789zNc{W42S2&aX9wc1%Q8Cx#T`e~(#(XQYEfIO4b9&+W^%mi$ z_n+7xJ^koGwU5iMdaH#G-7cIQHkrG4D<{L5-N>O_FBA9H$D;A4{qqGOk2e3A9QNkn9d)W#C?x7ch#o6 zTlV6* z_DMR`ZfXfx|Jd!@QT8-bw&v8+o2s|_xFi^5*+&hPD2Qxnw>!05|Iy`ar7ag!iU%8H zyB_~dkZW*^ z>*o{hn=52iXmmbxePUoa&s@7`vT)Nc%f{l#u;PiF3j>$i*XS!h#$r$@b>f4mJD*t6YF6D% zM=d@SDSe+R6zFijBJy|NpWPxVGZwqNlPJcLO?B}{~4_~S$y!?^m>8Vtn)+W2E`S91Cc*$34;^$sp<55x)p10t&RcV%U z)U$Oy*PfY(T#x){yQu8Rk`8V`ud!8!kM7+2Ce<=~+x9;1P1n3;L}wi|P)jY$saMRm z8g0pS16#0vUmv@Bq$PBP)DK?XBNGDYMe8;g zWOekuEUJtgs*;j9#V>!#ue?vS&M~v^X|XrIo`e3Skqzq<^k>}rneMY!?P*-`C#&A9 z7h_`zEKH?tl5$=Au1y6ioFPonJ4lNcni@mDO+2-c1%NQSXBji^BPn^|u#aXpq^OWUOB_=yC9n=lV2m zm*wl{V5{Hpq1ktPZIxr;_5%?g%M2ZNWM988s`^Uhf{cCf*6OzMYAm%gXEOe@ss-M; z+T{~i)m|`NyT^91f3n6&>G|z^|HPHX?H)R?V%RIrz^m`TO|e&2kt<(oe2t#w$|Wav z#GoZ~G^3meje0W1)vA5vX+q3nc?SkgGI`I*q1RcKk-S`kT~|*mKU|#?@l~dqJuH6R zNxop^+3ASX8qOZIKL1W#M)om6fj)^Z-BP#t=lt5~_4DmIE1Qd6+UZ40~FJpqWt5Ght5Fvmc9| z3UDu2lacCMt~j=zU%=RDL; z!$dVt56s{{qd4^K)t{0_Yg<6XY3>g%rW)h=F8Hbx%pWco601&!hMFXej2}T zE#U6SEIA`%xBJuM!ZD{MDTgse2AbM;$g@BJ*MrwZp*v4H&=OOzI@jv_L#M5 zR_?Bwu49SDEUF^X3yr>if6VZFN%I|hky}^tWa+%0XOHZw$g3Dk*uOXE>b1&~@;(Wf z%Y?DA{Ihyj#LBgzPpOiMo@jV&^LSUaLh;36Ua_gw(az)9r`2Q(6K&+XMHhDLY0PeD zDY*YN^RTz+@i0U6#O@NW*7WSP?FJ`a>gaZGOU_#)*S9Ct=Wu!6$1AplNB!*mYnp$$ zSd?0hjx`-idcI3_J&bmE2YtQ(D=kgm0fQ{#@=8^o5VY5M48 z*3nMgBaUZol{d`k)?DGZI?T9C@5QAg$<}5}$fnGaO=XLWl)WmOtYc;Gm25&tWhHxu zko~)!dcVK@NAF8oQnFKPMm z&e_v`$f($QnqGy`+4t*NmyHee_9V$Gg?^uE=6dt$_^eKz=**uu-;HT>uJf`>e`_8) z)e(h7;3s)kl2c(bLJ#-Tu|-0O9A_fFk&5bw=$i9C&6ajn{G>&?JkErTpm5{KS2}ZJ zG}U1@`C(Emt8&zgZw8kH^To&R8aRcGR%ZH)pF(TS!g|MVS#NBC>3zt~dVB1MvHuL$ zA1y;3lh*|tK7*Fo%~?y9Bs#7>%266+jrsZS#4a<|(DLTDN$v?{^bqh!c~;d86>zX_ z9i4f|_Df59Xo+g_R>e^<7A=7;i$8yV zZ6zcZ<+U(6EAs$z`?lU0ovQe$1+M6BJ%&3<=~?lF-Gb6bLu-k7q;#F%YzD%mLO#Zt zbUmxjvTt~?ukL;nQeE9KS=PCx>;K~H?E1;~u2`K$RVSUIIDfv=Yj22R)1Ty2RvmcS zB*dTa=)04Td~{sZ3<~OQb`sz}Gv78AFnF|oX{(X6`?B>l>l#y>p~z3A6Wr%Zo8OhB zzDJ|YV{bgTxVTMzDw4J)j(1$jN)&IyZ3txZent;&08p-=Z@9rk^_Snc|4ojd{8 zdUvA5Axe~n(_vNj%U7{!(#`~J-49S*xJiCGHcyhp>4>(iJ`H`XiAlsyFeLCX`hAr5 z@RzR79NBF>J(MrtLAENSVh7bkED}}v(n5vq_xH7gO@FD0$eMYFkv?AWT}b_Cav{c6 zhk*IN`}WhH@EpM;$?i(-7mp^r$66g~B0BE9OTQehxlNzmTmC&V-yD$~Hpb!|(Os$- zR^M&srp|7Umx&*m#j`-PE66;X>@OGCG3xid!nM{jNAh01`^ApK6ZL_KNe$1H*TYXc zCw(qG>19Ou!r1)udcu8Yvig_Cu3^dCwvTvVf2mb&ig?R4`HFRMyK2WmuV1zQVK;7c zdDFO4LDz~zJrj;nx1+y(jq(KT=jzhYBC{jqS%!>2@F@`zoGJ1jt&sT@S|P$${rOh| zC+P7|sFAp$2%#&yg@+zeWSwE%OQ11+26ggdcXA~WdIk6d<_UQi#t{aq6X`FtM~{Q* z4UN1eT(-rp=eXJg9f-U0@fhOUpDOrFh|gycXm(K*`&Z#G%t?k{^_CbY(dVv=;u8k|P%W>>CkrtF_49O3CHul+d52sY}qz2U_T z?v^Actnb$PjT~F18>3y6uKol&;T{JL@kda%CKy^$y;a47)u;q@Lp)2+&Vrqg@y3TL zE>_v|gY0bRWvjN75*&P)h=a(GfB+6H344|AZK11n@$E-HU@LxbcCT36i#abEe|d4& zH)G}9{&e2X7WJ15LdM&C%dS`E=BHs-j@($o-Ri`X4`ZBlQkx)AG3dfFejdE|#?a&y z4Vo*5=1od3EiYDQPeX=+Lei_|(Dh69%YwF|hT5Ct!b;LogIr7#>0*7Dmd@R-@S-=} z6?0%;G-2aBbIbZ+aW+;QV!724aOY4J7l-SUv3@G-yMj6j@)xGMY3v%q*wx*+FF9ihTr<4Ij<1jnDo39K3jM|Iii!#VesS@(f#Cupz&02LK08()_Gx4Joez` zfOe*15`I`hgCvXETp#~(>jU!RQ#X5@`>GG?7S%hAl;%U>>lU47xt=mv$4YTHaC-dT z)?T`!WI%@pUM$H{;4B>jdo(d+OAw*qdrd*cd!lYI41e&gWaWP09_zb+#9_fyX{Uxg z+Q8o~)@Ej!K8K=vei%qNeatTyxU=%WS+m=LU9Q2D;q#tlt=F3m`@frq@kzVF>TT#) z*K^+FfA&k7_1cc9XXSR5RrrHp%22FCy znlobc#MtSLW}&@`s^8qfF2JKoG+#)W?n#hB2C123t~X2jufzbW*DRTWO`>E~bT z%pMyY$DP)-@3%$P%n#O+`4mdKd9SPoU#LYLnx+H}r|uW3znNN4d~66mr^9%IT_x;% zx%k{ly`I53&RuEyPhXg~ertC?F4zTH8kwGMZk7l(^Fxn@Zc{%n+~N~owXVCIV`g>_ ziag4;S7g@HiMwX_s3Qeu+9btOw)tB>s>PR1GvL~KN|14Kc1)JwXU_Ls0;&!);=z8l z;2s30h{j1_?aG=-(mWTG_ITwF{69AyBWEUuq|bSz)tufDW!ARv(>aJVc*Vfb@%lszWF{3V&-C5_dnpa8ZD7t%=ZocF>AIh5X+D(7X=*aDX_K1*>h@-uE>)+u%=>0L@ zU0nX_>}580b678RySLjO;TJm_7rRe{>gjcy?9l}s9CWuDE}su8EXBe?NE@%|#8Fb4 zR_teI)31MWE;%o5;?&w<-3^GX&QsI7^|EeznNljXaj;3r)%y-QP-SC`fzQZ^`-k-{ zPSS@D5`>?n<{r0;$P_3J`JZ9+`ow%Q3b~eGuhr`}{@|hn`Xp9=84m9!aS>TNth2pq z6eL7B+&dhnASHokW4TaHWBChaFRo5~B%k5c;vzKXWS!P7do}a`I-?DD7?kv)hR-F} z;M^Q7<7BmKJm+RJ^5J!I7#z@YJbZB;4}=qiPgAx8ib*Qk642zS<~==^c+^n({?MCC z(wQLB=LjD~wD=wW^yS+Qdk38*jkPJuTwAr_9q|xG_5Q+Xr zBsP2eIQA1Pq#9u&ecq+h3ujPQje5&CrC$r2*xuUPgW{n1`Z#}&!4M_o=0SJT`7o5J z#_#Vo8%Mi#@jUsxyEif2JUW$R`IYX(dDGN5N>@ZQb%WX2o%*DVghoVIs<@qgCuu0e zV`w_|cEl6chM8LDez?Fm)u~9q1AAZyrzXzpQ!a7PH7@iex6aO6-Z+vTy4%srl`&A@ znOgRWqld4PWK=N#%Xs15Dzk!@&-DyQcut)x(O-o^j9CEn}Aj3a62L z9AvyaurmJSX#aj8d-p(66c(l7-oP`hbBd42^3vk979T?UN}Dw4$YY9of4 zKfi{SV7vN+x^BjcOSqo8IkCUZfj*QjFz@=Lc$}q%5i4xeXb5x{5oE@>`FVLZrSoQ9 zlMGt>gI|{&$GtO3K>T1$uW9uue6N*sk5j~Q$}TX2Eu_P!lDuF zZmQnxBvZOGj-ntcl5@?Gl}i(dcOS{F+I7?R*`P8#ew3?v`}ff2;6weK8$WdaJ5$iz z5GLsdjFh`>C*VPMp_tL<0AqQQn8=Lq#b5F_TryQdUf)QoC*oFZSj?ETnoBok8LyCGoz^J`x0?b)&`Zik;aB91K~ZsS;~4lw~qs1 zlU+^xQblXOYN;5-42UU zo~6$UNkl~c7`RFMaN_6NjS*cX^a-_#L1QfYhRp-bYzY!1hNJrvTtY@ZI1D8Dn(fWM z>DqDlJ#-rN?3WliezyFoVGYkogNtGjhjuTS66?VS|1MRmih9W`LLl!Ds<59VhZq7P zTw||oOs#Xn!J4?+CnXjyHhYuzrq6nXU;g}gA8Z#%07q=T?i~9LY}Gp2i+9&w?TJ;@ zyMgrvAcago-hQ-Iy#ZNnEy0wRkeF{Uctg0qf99o6UaY8;g@gb2S#Pk=1-3-{s#hJN za|hOUaqkZ3!lr@n+s<87=l7FDS?~CRND;kOl1?lPZ#Qj{d3r6osb64z4 z3T|**^Dpm;k9(@6RG*=B#y6X9NFI+ z-DlB9YDbQ<&qBvej*7Gn;u0k|cdwBam_N=?Xvy7`yJ~gAlHUlKfnw2Lir7M&=Cn1s z<1`jsF#1820cV&wq#Lf@V--A>Y+NBNXDsILD?|sT6nz%9<&u7k2`l}5=(bl}sF0o= zHzRbINtO>g(qNSU>WL6@_o?ohx$ATe`Msvm+3r;C(G_m9g3P?%+OYi^IT`6(UQ&0U zRzp@s=h3xLMf%U54L{xW<|`m=y!!&nindxC;KpGmm|X8f+zqI6z<>qOW7e91kE()$ z!~5QyE6GoBkyr;Hc|_=^d3C7r*FOiY{|_L#%KMYpdaPAVlS6Lc*qq6ufJScKEpQ-@w@b31L78E#N|QZNg(+^7rZ}9U{0bsGJ~)?yOf~w6EJB zu!Bx}>|M@J9X-TD#3b(BE5`4WnoX3b^ByY#h$svQV0oK2Y|bJDAy*tcq6?y026~K3 zilN++cJE`f@l*5CjzehzPBRm^W|TA-!nDe8U$v)C>HF4+3xh+++H_Uk=lRN*%3l4$ z+`3-0?##euMZ@LIF!7nq*t zV>H2+sOm>`iF5n&nwlZn2L9|A8=Ddph3YSgt&<&7{(@M_c}wm> z9rjo`Z zz!VV9>{yYM}U$LJAA`J3&-%b=ASo0Q{czOw-_D5U%V~G@+juT**@(Xsm`(SH@ zh%9}Mw^-;yEmF?W@#^EjD}f`%boG343wWXwg91!qD>0fW@fMogJA5=h;vqUvoj3&&vfOZ*LcN6MWH% zQ(VZ=j6FY+B6?B!i#x}U*C${{YGDX%kb0Zdt!(v8g`p2dwcHq3tP?st{1KHVZmZYB z<0`kWF9j0~F^S=!OmE*&no5?mTQ|sAn56$+@DS(UL9t)fB5-&8Syc3_*SO#JFP@hR zugnyfg!}Q5?`I`T&KKXdvLC$8&7b?wVWU>lK6dKs2Q!;Z_G2x$L%CcGWLJ%QUBpH< zeNt=g$-#*H;iJ>#439SBAEB-AWFktLjl9&s1?RF=A@?c6a$XAKpe3!HL6_|C3}y#( zZoW6mSCe6bPyTq z<6TH0IcMxPOM`@3fWA*gS(z9}jY!-!2pQfhjv!fq2uEpeZ?EHO9FS0PzXe|yc2BqI8eF$S zmn3p!Z(vC8T3~FrRWCj5FvYvULBl1l$ZGxTLc%J6bCYIp4qx)PW&uPIau*6B8w6>e-d*x8A3b^>YW-kFN1qm=Ho?8?^PC z4m;5}5RE(QnnOOjoYr&s@yE?KXy{V#q&Y>@XKcI3+GQpGhSE`gjp?LQIM{88h|Id>r312@3OMsu4Vww6y*zI;)~MKEWi{+Qs5(iz6aj>w8(}B96BT+ zIlg_o+=dRc3YOZRX;8_Yo63laYV~`03pm6hJlR()(&vMEM$E?@L3D&z^MT;19a*7Z zU`UFj4l~);O}R7znfns>WQ-U+Ny04ZWvZ8qXis;(vesls-y3 zqS2P=m7ZRI)xs#ob48SG>!1O9#Pn>bkx(U;Aq<96Zs#V@2}cWc$~3eDfTmuxu#)rX z!~{;`qb&3y{dATCh9hgs)M#+pyW^jI5*bahs&mJDIvn}p&@@2xh=iq|1;$dOa_(Yf z#l<-w2n7Plucz-jqe7SZPOj)~JKnAjrr2UpqTv*pk700axPYwF{x)tQ=A`&GnwKx{ zT{|&JSN)ZA8^da1wm3zE3dmfmq`3fv_oh-|6#oap-$lcXm5?tig{WIW8-D8Y?bp6Y zv=czB&7k{FTFV}={sk(L={N(sX5aybflc8L_=d-#CyggZ{Kp5AC;GM?U;(u;pv0nQ zTZ(s?T^X?}nd?mmdi9Fz)j4sUPlldzZ>*4TVA!pDxWu=nf>_#sKfV5C7fF6DD;qRk z0OOrMKm5-)%qKnh3I7yd)VB+a=nO-&5Lyg!&p|zKDLy_4VKYPID_kfAsF=xj6C~UO z!O$-hTx5b~jd_=)(sE)J_`$tCg_Sm*klFuxnasv|<0*%vzcHI674}LK`zb4Te*I9; zKANcCoX!cC^5f(^Hv&&QO5;~6Ri1B#4CTjx@gLQKvERSyW*{&PcQ!@2O z{MCjtZ3;OY6W{tp?CdZN7&L}8&^Y8T4^2(Yia1tvIH1EyeJ(|s-IEE4!uI?{h@~Sv zPtZ{CaMh=5bkSr_KPWY68xDAhExT-M$o$s3qDtjJzszo&5wcpEz_vi(-UmD$qZ(?U z$bUfgQg?A3h`1$AQYXw92&Ux2UQkOEr930`Cv~{O@-g-RLv|ZG{BX=`cl^&x^qm=7 z#!v#V4XWDyVjuu1M_`CY0a+Lyz&S)iOY1L@=db~sx0sB-xm%QM&|Z9{EAhS4Hufm)LBlH?rAEt8n2b}i+`zpRh2a!Lrl0sR`*(?3(CJ!nx8JDc1@P+T_`7>umAhbqRr zJy^{(GfPvx&2unPbNvQL;L>B<_TT7 z)KE)g*TT5VbWOk*f{r@ zs#wR9qg^x5@gk9~4TqbE9D`L*=d zIx6aQ72=xST_7SNC6$ACvRN%GEmEO2oE&>3O1M$M35h~nNP)$C9{KyVnq|Jq5=wUE zl5dYEJnQodl5CP=iA8l*mWgiiX(l8k85Q_uSo5_k=Pm^&Q~wFo>nP$7Vjhl+?Ee^x zpKWL%ihFZ0w{fOjARMA$znlaI?oJW{n&d!VZCVDA%j;ftSg-lahykDmq zhyQG`(5sgz(68ocT)OQ#aVj(Wu*m010nRndBJy*)R;$mPQVxNyT1{!Qw< zWtr#*Z9P3RPl#n?)Wf;)@G*;3-~T*?C-C8nM*?Ugt}&iQOk6cc;N%4fy zN#*^uv1-I51?&xf!5Y%>;EijVv8^0OO!mIwF>Sj`N*(@^BV!W2O)jjD9ET{xUp0On zOx~M;Hjk}0RvIF=L*uUyigCA;*eNv|nY59?%IG3f!rZ#*vO#>+Sl$03@9A{oGqLLP z4wFz7HrM@lY|&h&c=Cr(H7*0eD- z^LO_gHHwmF=UJ|zmY(Cij*8Z-@em=)?z=g<$Z9kYj7xL^=O}JUe@M#>CeNwua_8NIhQK}nM zg&ev1P0*>sk$09@H6Hz-aSUSIVZ1@!K5d+0P#}C+(QI4#&UNelzrr18x?e&~m;oGK z_4(|qtov~EAU%1q`x0l)@I#5cndqmx<4+ucp7?WT)V^rvEpumAmtV7~Z>PD(^xN5L zhxr>1pvpzeQxaBrdqz4C^BqtO>4jD&jO!W|7Lq1jziV5!(Z*C?oXAKcde?;WLNNO* zLzt|?Ew^ee&F2@d4=*;bhd5=%<@3noDGcJT%bwGYQv)_C-tV9@ay5wrM^h~@gOpX* zic-sZQ>(&z#N;ZrKwZy>A4@Vdx1xgL(5<1)dPNre=8eVRs2g9cW_2Lcc#%NgDV(e$ z26_%FxK&pA@xjufMsvweTvt5SH)@oUt#y@)-aDxvDIBH&{!hlAA2LB^U0I(M4!XLZ z1wJlvhjMt<I&p}%;J9D$A7hfY;S*bxGo5MMe$HyQtaVlq6j7)`! zk^-x3E?N?VG!^)wKDkxfA4vI4fkOSK&*ymBY3&5SqO^DK@W4&EIJM;san9n1}(H4Vocoe=~D$`>2JFgaNFvhMtX2ai>cC06MUUXy$mue8&_USFo`lI z6%7mu67Ovov$Crm?lZjJmxA|&Mt9TWzGDDcJKik*g?S4MFGJY(EVZ<-S22|~=@|p2 zAI8IWN9*ugnMBf_PG2^c+@}|83zhql5srMC{p=PnA70IorKf&MN+TgIW<0|ni1%7 zW|d@@cSM&|gj(PtVZ^=t=FCD};;4bw63|Epq8I+h43)_ta#~ZtMa<%Y*IVv~~ zh}Z6)cO!Tn5$7?6n z)!TIZ)v-q$RwC*7K4uC)1b)vM-8=dE`QaVu*!D7!{l`rMn=BS%W7V6&81r|dCz|@( zqMI!yuda7uqjBPjeEeVOK;N(V@*$qrerj4byt^9zUw7NxGqF{8G zXOKE2(POmML`*zk^xoR{n&WR-eAPgmLF~6K$^{~)vse;2h!^O^ALo0H6xYgYegrrC z0dPTo{`~m`*erJenf%15l1dn&VdQo*`7hHNO^2t_S^zGa5?eh&V8MIJ_~{h*z8l!r z)0~m~O(Avs$84o+P(3e8;^_C_9~dDbc5g`gcSm9a57AD+IYuS^c;#lq%Z4{M^-PPw z1`hO#nY~}3?7M39iTMzkgg@=l$Iqa=n@RvWD4hbn+#UC$?GZ;P>X9%>#A5>P?RmEq zP=kF1dlOJ>uEmxQUEBFvuyYDB$FgIlBw94TH41d76W_Wk!=fCi_rP_m*FfWQo``H9 z3@EsdFq94}-&)p=(=I7qI4SMBXkmdv;9ocXbc4!8WO8NjWRZxAW+|K{DDNe0^V3-p z0Hf?inBqN{bBA4}I+(WJq^7GHl z1;6!v;*j(rxH@v{jNP)yEP9zLR-I5SL-DGm`Pb8E20r5$pZNuSAl2uUsdcLgGV+;J zEVkzM&{mocM6PQ#2=5P}7pTq_w8b(w^W-ii0D2=w!$2$`GICEDlb4XoRo6yc1-_`k zl@soyuo%uE1A4HPNtE(pwE2?X44v|Or&*QTudg;2m`NV~YC$YL?CLfYk*QRb=bm80 z@xIL?kiQ~!Mo{a`POZ0j zWy7^krSyRIE22x8S7G`e461L$+ySiW>+6gPbS`mnA|6pXTilk6xJ_^t+Ey67H#{h8 znh;oBWnp4^2|*KICEEyDJbjU*2N-deimxquihP%Gjnchuq=gx+yJ1S+b>p;XvcrV4 zIunz@>~)#~Bb!n+F5ZBNvG_aZSk8oI^qeiAUCQA|=yf~UzgzCEy+@Q=pCD=c>QhrM+Ro|S#Ps#zmWM> zpO^$pvS-eBABYeLvy3pYxVkcON9gScxvJ zT3`iR2KcRMnoU*v)Yw(q5>p3dv_qLvJ4$(QlycwUwY|y{p6>d5eCVnDWXyHM2+8~F z-2rF+`!ENsdc4XA*(GzEqH*F)5UWa>q<+O7a+O4pX&?ytFSY6^kV&7#@)o6Q{zo)K zzZ8aq2m8i?>gVKluqKUCPR`B(U>^S6SS1!5NTn2b_|Ax8W&LnsCKc>%tNP~<-Qv=N zLoe`j+Jm7E%Q3xvrQy_PgV%GJu2)E=U~XVWk3(s9N-l8=t&zrj)669Q+!{;9xiXnE zQ|9=4bc=F^l^&h6G8Q6R8@ubsGq%q0!z zsx_?=FvA8PjkG806T)V4L2>1Fhzzus?>%iU+Hb5qk{yG9MZrtte>Xn8STwIH(--#3 zsr_Tq>M2ovkoQusBQQJPJ=+DNcmE4S!tLdFUNEasn!B54dh3qAurPUAhN4WA{7l3C z+m`DAVT=}~1w(b`p@*qS5Wg`tV*J?LqRs1>I;F!f}_zFmEUJo$VNX&2zkk=x9yPUleQ^lzu;1-vIOAdc6MWhu^5UL>RAXuMWIX#p2E- zhlaF$$&4dY3MByQDA$^blq?Mut}B`v_t zV{|dac5xh16(7XUuYMc!wul9*qzpJ%z-GkBY0nAM5 zrB@=^h%ZNXF3vk-3@m)O1m!@OFsI7Ic=4tUG+sZLq8BN+;R@aQSt43w@!THR) zu3aVf+!ahjm0~jOU5Q_WU7_+VSs4Y^X+))W{kVYAn`#UVXn}rrc2ba%YHrS_Kr5tq zUdXTpulPL$isIvOsMMX&G6__s$ccJXqwq=o(p!CMPwx(?)_4PGK4VD8Kgh6(eJ zk|?acLeqe0cI6AqN!w!(RfQpk?-lm6cX#cn3yes$+h~Ynny)5HlA=0%=8fJ*s-y-n$z){)}#Zp`B*nSe3{1A~$ju15A zJcoAkVcnsz^{!pV-o6d$=2XLx4vtb@N@hjui5}~E$!`ZdR7>WjQ({(P;^r$Z=q=r# zV%Kz9p?{ao`r64URUm#Z?v-M74_9~YzLf2jCC3f#x)s-+Pl1=YNPM^FX>D;&g*EhG zui0U@=ha|1_tDR<#O?5rcJ_<(U~v3LqR|nMFYET3pTk=aH!`yJSar@UX75ROrKq@x zd=7aBr_2k@&F?iZ0ssWXzp6o&0uK;&F*31euL*By?eP!0bZ%6WCJn1g!zJFU$sH@3 z*L)0xh5;~69k5P3mf9RTc(Jav$r&ZsBJfomX2he(V(eGC{MlqMM54GBGAR-Fkulfa%6@GTF zSJuc&Qwt6{?$OJ0$_?vaZ!gogo_<((VtT<@$EiFkfOB7z7LfkKNtZpFyhryeMfWA( zZBNxXby-_a8B?m&8j2&sOJ#G<0?h!Ox>Oh)KhU;dDs9R)m1fa2Fu(ZLAbz?Gk2&^@ z6OYi^=#ZwSMnIq}50_t>Y+wdvn_$O2zZA;UQ)eW60S>wk0{1@)8R28!`>4#FVwzva z$6Hk9W@)aF{4iTw3?Vb7*nIT33)& zIwHZ|G1fj16OBmY)aUj6qF{VO#HVsQa`(ag`!o5K=`BGQ8OSzuJ3BXh%`^-Mn7xux z2V(Hg;lObC*wC#>bHoph2K7;SN{2ebQ`3OPJ{Dvn0=7^vd%YM+jmK0Aa5o9zf0q^y^tyZ{G$2OTB#A0^0r6oyp?gc4fbv z$f)m$ZB5!=rgYEH?al+E09LLHt=yU8$MVl0v{5FCt@6Y;_rWpcYO0tV7aw;nmvuQs z&x3kD9a7uDlF!d*PKCi#Ov`z3Ftx(g=9+J|-5CAU8A;upfPvt#%E`dvA2LnMsl-dO zbqGhsLvYD|cVJ=Eo^reSfQUhO$lfpXW-Hub=1)Sib01=&wJXl>=kMqDuUY7YX=zf1 zgrAGAjDIb}e7(5XDUN1OxO-hu&b@|N;){s}0kgqSS51dk+z187a>m27JG^J~A0!C; zSXhp+u#c5d^a|N?eyPJfdEsV5_h&0@9u8k+NB>yY>Zjffxi&n_ZP3Wy1DFuTWYygb zp)}j;lPuu<^0m||y!R`>geVj^OXNjv?~|woUP4hXhL%_THIH$Uk+C4=i{~@C-yew^ z)@lBMeXE?k+hcrGS1Z;t1Xw^(de&*Gw|pF^eY6H&d@+ zV&{`iCdx?Nd|&l}y3t_PW+-~57OJb%6`@cNC6~0rSHL@>&0~R%2Z8|*cg(+andco z{p?cEm!J2M{ksRwTLg>w@Im$vJi00H(kkL-qB8eGr1N7R#?RKL4?hHJ@4MVKM2EMC z_h^dWfz*tK`BC{eHavsXIdx`+c+c_I>FX>G)Ylm0B@V(jPiMOC-J1dn;OH1UWwJvm znJymAde+yPk*=M^O*Ph3w`6I&m`Qh6PoZLtx8v)0gO8Yw0)aZ9Yr!{d&6$g}>-T+w z^=`tx!)A~v0oTUu?N{BKdHARWZEYMB=)Rzp$3>M9Z;M9Z>DYmO{75vzv^{{T7~hs3xVrRbeiwLnO+zF4`FJ>W8o`3Y9r`9O`R@u-|$* z97J_W+(=d2w-r0yr=7o1S)^g9%-MZi#~kAX2uDihgRc}Y!%qAC`vLWdH(t82{cxi| zRe_avT+W=HMY~ZQ5s04r&&3%h&5WY&-Gg8KjJCRZfUlX#p6ID|sXMU0+kD1@Vuij( z+CMHu0oC8Sbt3PRTHfWrbHETIuSqVtL`C)VuH!l9W=|m_D3<`BvwEU_|2``qYBQU< zVA3K0zH#Aqr1tP&ws{`1lmsk$=n+^D8dRgxh2KSBe>Q0{`=?$ zIXOA{#>9~}fJCJLmW3rL)X>5V@GI$6l+HZA@grxLbzy&5vLq|L+$Qh(d7^8l8E>9r zAV%rTGIE#jqwm9ck<-AIA286BiBprYPLY{|L*hW5=ZZf zT<16*J=PQB$G`2tq;>NUg3*B#2>(qm%SChjC3=9TnF`2G-@JJN>j~fdJ!Q`8sSfeY ze`Pu@Qt;er1(ujvgNE>kx$KK-GV}iJ5xHL)Z{KrHtCY8$3dqdIUc7-H<$ujvwMmnR z_*s_brJ{mRF8;DwSG{2>APdmYTGD)Eoy_tqAfw}#@j3Q8bd5--7TgmqxS@@Ko$z%$ zp|(2tM?YNDmwKGeN=^@3Mn2=`QVMUj`+PH``#*lXdDVxY9FJL#FFEJo2bT~sCCUGh zRFUhmx)sG{B}#Jf%ilRAj@mb)lDO13sT+Fa8tzzN4CgRiFZ}W}^*tUlyFU)79c|m# zm81B(7Jp7(%=xq28V_PfQqxawxyH(wGOhY>QJ|xlwV7ACCLlo3<24PJEbDWeD3N8; z9EClW&EpC6*ee+eKWloT=#y@%bF9CptzF=bg3eKzh%1(PMnqT29?roRFUsF@fwu7c ze^UQ5MO@K+5o826sd2tqijpvfj^*`GImC45ekjFh@7K)0UO&&_67q<%KB)wowX(AL zirEtNM?%e?fJksxmRHIN?`oV85s6*56ZTkMBe$#D(YVGOv-z4Ga}+dt9k0t+sUDLW zf}`Y>$ZK_y3w#{v_|z`!B79q*PwD;C$b%>-jwvUb&Y5J+QQ!7<1mqxWim=%v#S z-O+b1+PfOGdk0&&rPkkN>2Tz}T{yz0+{<-<#v@tsKlU+h=rX7L#68Yz5)}3aMS>$( zSCI6EpXTi3kbLmTff^65e8W&DE>$bdTZ;z}CdP?RH*Iq*iJ4S8p zk)tm|*s(mc933Vgp$2w5N>_b&QWrlgT6=~GJ0vim5fn81r6l-J-8v#Y2v+RB#q!*@ zI^7dU97wj7Tp*bY7@E0=orAn>B-jGtocbOXHk&p*ne++*n4bl(VZ=Qj9NmM!Yit^L zcF#L!{*d=TxHlEdW#n{q=^3wC`--nsoq}Wz;2c0YLDkXDC=%z5bl6}n_y-_>bD!=& zmZ1L`sbFv}#3Llk0IuX~5&QFG1ah+})Y2I}4(M;44rqd6Xs?`XAD$=!Jhvv)YjHXBwf~LsZd+HdD1@3mSZ*Tl+YEZQz^_891K``%=w(W?C$IJ zGv~9v-Eb$b$2^7&J>}T%U$7c!ux}Y3Pmx~?1V+D zXAXAiiSnspH`5dQx`9Elmmng}3?EE%bEjiJb9OoEa|lIs?&dBJ`*X44qD-q?^vE3m zMtz;TWW|&DioM==A`8}n$$&7wPd$mxf!dm5kJE;H>{;6Xb30LMib1aeGzXm9u;Zfx z1KUX;8QFoWfy8FVC3;EsYY>bF$g*S|1oTzGA4*!CWQN_8lS4u<=Wl|9+F*pH1e1mM zc6BfUh64d%NPx?qlDRq9eE^?D0rLH>dhNR)R_t8P1g{6IKf=98)5?7h>2WSgL&T_G zUED>f;8F4v-Xnwy!N!3r)F*u+fduma&E-C%mp}kT6oO|UHiD4l`9X4bY&vqjj}B>} zO~o0)Lh`E*dktpVn^*0At$etZrFt76K$^dmDy?khpS@x{1om+XB-|@Zv~O^|aUR4- z@j7zDY2|}?nj`5Jy!p?ua-;8vRkk)wZz(_1MWZVUZX7u5Gl48BB>a5&_pBA-TMc)Z z=t!;0eVctj2|lx0_a-FA2~LqHtUHo%g|zj4zMT$r%Z0U%{6bOvkhP8YM1w&Y!stSw zU_G@0lJk(}5uPvDxFQ$=$`_2ubVI%17r)9F`gAcfeg>2o$eV(|59-MhFMv_mJqmez zApi;8ha@_u_mn8sOP4l)gehAzVg_NDn>&Yxp83Z0jJyUFA7SzH*wf+R*f;E5M*SGP zBVDjE`139sJOrR$=(>Fdm=$%BRp*45>lH>=HNEqDtYg@ekM-jtP%;#C_1Ca${?(Pb znD)*m90wO1dDRiruG^DTouoq;c?nB2!B5A|4j_ zZ}AlpzxK@#(szzf$#6{SU$kC>DjEmb!iv9H0bYX2_m7E9Y6Q#ln7Fw6kV+uj{1*6) zNTi>~&WhuAlh2bcT4!;TV&P)_^iSJv%?XeC>uC+!#8@z@()>LcoX)7jc}8ga=uvrF z=aU>J;pY_hkHn+l_zDNr z1NNJanBtJT72JjI*(C6P4bYGX-1se+ra7Z3o_&B^k4q8~%*BuSK7ci#N^(YG5wmtk zXsE)|r&Xg{yrV!ih@#Q6c&5mu8G|m5o+A2cyfSd@A>QfkC+?x@lpp(38a0mp_yI4N zmA_-pAN)ixHi@3WY%;cMCxobeYmN6dbZiQi>nXNW2md4rw73OKmF=yoWYiX^Zw1g5XqTZRz{6X(T5(E5ALOSZW1HR zUIdT=48C?)*UpX)Tma3VG!ktAY%pe+w};B_D$*ItF4^KvZi$5CoA2Y^K<)(Zby}!H z`NB2Nj_KuTIa+2NR`pvy@BAdD7Fu9(YIu9(?*0KM2gl8O;347MDQ0p{FM431)+Zu? zdcE(k@z}5HaEvbHkM=}ic-T&tM}EmVf63+Tg=eYF{u(v2Rtd%Mf$73Pi%lc9XQ*_2 zCHPQG*c1V_q_U@&L5X--DmLr4}p_af7ejn0Mk$}R%J)r7cviy|z0v^SD-Qo|bd2H-?=fj#S7{a1B zWV^a6C*UZeQZUxorNZ1Mdmce?tF`rN=DUtjwOqQ{_J?s{Tcmy}0l{~TWgheya~M~W z-&Z6+hcaF?pSe|JM4!FHJ-%5wo?0yNrIgMu=zHAG<;jeYm!;pbDh5-cJ>y|ieE36d z*%lwa19-T2t4FaSs6UjhrCBr$b5>00SfVprwPPCp$QP{XVizyKWP;=8Gs5#t5ELZ< znNFU+2b9$3PvE#iTpPio_Z2W@SU-JIXzb{?74#%@es;FgIc5u|{mEPSFb#O7v2RT_ zn8P66>A&e)ylO;^msB|;{#+2HfB@0ml`Ah;`J_ci*`<%n%U^nr@Wp_L7Cbg8N9i{H z&E-lJy#g)}LB6xg^7TSBH(n}Du=)Bpghp6weiu@2MH$3q)FD}RR^_=a*(D7tT4>qbs8y&}dp3dyy`;}Do*E}8eg{uYYHq|HTP75--@M2xdbSNi`mBkBxz z2Bwxk7(>h6$ekIUhQGTOmvHADhml2$kmW z4I36uHP9?=0#EaI!DCuP(TqLv^v=+jWUIB6=5^5r7d}@9lHlaMPK#o7WiEsxszVLa%CDA%Sb-jseN*@+;_U( z`B)9DSh@|*AU**&3mP48qBuT8L_HKzluynM4K4U0x<{Z)9Vxu*b}KV;Q@Hqj1P-}$ z3lj=g4&H>0(f7?8^!84dpbscKwBKMo^-pR*iC^POEUMp?gIqGy1?`aQkx0NK0^1Ut zg6Z)1LCAXorXo!esLuY18Au9vepfOa$an}b;d{AF5M+LKjM_sDQU(F^NJ0$|9%ojo ztE-<*hl;>}#$nQHw{sMN8m~J3lt&~Y(03Qt8AY+x{rW(P@`WJS+9S_Dm*CK&c$75N z@$zYd+(>%Bp{~!d#H$AnB7X0e6wAk*T-wa~Z&N;1`sdUsdlV5_U1w;{9CYZDJl!S# z96X-!n|_`}2=WNg6O={P<2BcfeU4lcZa{gc2zH_lplSeDQvw+M(_b}3HS8~OCsYI9 z?6sIH_7Das1u#W<4#uGwkc9_a1D?&Tt>TBnT&Sk(6z_+a8Vt~207jt)j1yExK}c|v zalQK=;08xPBoFN2OfP6knkPKB)DVy1aH)L`M6qP(bF>X;7@;>rzawcEXRoN|zq!(_ z^a83Oie)k-7(0WOn|yaxL1G(!JsSzo5CAJ$m$k8P zbM&VO$v;BN0(Pa$NC89OkAqOw;q(e1K709VPdLnU0TjLJvGXY3r13nwnGjTn3y};j z!C4m4K~*F!9V?Y|Ih2u$Zn7q?=jdu-5;P|6qT78hMc#$Z1hEA|5WoKq2691k@!stQ zf1I$UCU}}p!9$F_p~Dd74!Ri_6tj}JfGe3H@N;0~@$<_V5COjaR(2Vw-<^2rv0gv2 zx`&|M1cY9uD=***tx~*nVl($7i{AkbOme!MaauuQGL8I<|3Ai^ru$SKh5-X8<))NO zrYr;(_V{j@kvZTIOrR*}pGGrj_7b1rvUk*K`cJ%$!*7ZkHgm$Js*I8GO&!z=-C#|& z9{}ewNISb~Soi$%pjB7p6}B%E>A}H=)TfxA`NM%BJu#z*c9wO8spo&SDDFALdq&<_ z8FP{Cqurja@yD%mpyco33ro1oA@-C4l^PZnC?=5~cUMt<#)tmLh0SySZaE>?E&R{v zBo!uAo`JfCaZ@?i{>DhpaFCQ8>V97Ct|<1oM5!5jwMk~i;P3B1go#1(-dDyT?_atv z+<`r^7dT3d_loXjl*9uTfgwM|40ZTuJ8 z$@3xK{r`0l>_0lllqR!nzf<==ET|~{XT?IP_ncD&OZ#8%$MFLJyIZuhFhuv721bY( zPlH*H4l=?1Tki1y5jEg{rTX8u=Pa;=5kEi;{L**WqG)+a*?`+aQQ-Xc;y!P|-|q?b zd{Gi)c8~Kv>kxb~7!zH&HHWyq06u#Mz6GgWK1-jJeNW!`{0L}FwEq7FHRc$ta^L?z zYTWa1FV4yz{$B-4pfbDV+}9XF->PkltG*);fcsR;yg&SHgHxfBq{1!S`e7cVwzcVCcBgZ3Xt zA=$%!9S(NI&;9efDuKH`RQjJUPU@C!lOs$ADR|EPG+KZhQiL)CLd`Di?u#r};(!`K z0_FJw*mY7(56D4gW&ecd{>NZ_T^i(SDEAmum4`oQU1C#qI$*^4_g5_8_{=7tb;|a+ z9h@HA({mdMrG&vVB1-xN!;cKmwI%CQ?OiQ?4U-ly@#{A~ZUnP~|CjHFJjmWJ0lkye~XoBc@IO zP@M&)C)npfU1H9L{(f)!RIyVf_e=CJ@uhzjr86)|>g(CxGzO@n<0a~87i5w zFjGM0>mZntDbNv5RDE^_9f$r8>!ggFEdAtQ{x1xgBf^d6$LpAwnOr$-6cxO(jLbiv z3$`aL@jaQ9zqo`v`fg60V^n@VkGYkV$258-n@u$^Lx)r*uVB2D;_s88Sss87tN$OD zmh%4~?akwA?xSz|C@M`V zAxShR6dIHwb=N-U@ArQ1z4lsb>q+$3vnJ1y>?q}*dGz*9t`3fH3I1I;qG0(zVpXViqJ?r8ylHbnwn0So=Plr z7_lec*tqV9?xH9IEA}vJ83n&DC8L5@d$5Z>ZhCLROnLci41{)C$aSEe5KOIn8viBtyK1I{k*0xOkU7^iuS+2b$ z3ziCgek9^3C_PsG_HPmSTT4$Lo9n)&Z0?gM5$f6ZzAZ5`{i)23ql(XAHj5FoaUhc` z)q*@pQGl)`!r(c3k+du&{Sq}@ii?ZU2u^;vkcJU1*zuqteA99b$ZAL}K(peF9~_uV zmkTiEO=ptVrmXuR78;4J>yqqMoTi9udYjXCr%P+L(#5PVRlWZ`l?}5e`+nsLfz)}| zqEy#Wk>Rt}0y$oFjfyj7A;!T2_fw+-2Ax}-ph%c z3)~T;kWml@8c#@h_tS55WTav5p@8O}Iw=P12rmn7T^*^+X}mtFJfw}MuWs&RqbFcG z-Kxl+@^j}tp$QgQP8RB5#~WE7!`8vRy0iDyY-GZ>AFS)$yX!$3l#B8|9kb8eNO_~e zT*ijM&%(#-B*SeGf78qKH}(dI*xxxX&%-%1*t2QVg5#c{(hvC=ev{a5w`kbP%Ib;D z5-CYZ4qu74r!s9SIZ4e%y6iNoQWN7Vu6|G-;~mI+U^HP&k#U`j&TU)Dj zRJ$QF-V}RbFv!?+NawOFIt{?wvibb6zSpU2`)QmGfsijOvVN5Cp3e*J^@ z_4w}RRzbeLoPt6^tIo_~tW3GC_WqF9tRt!x59?}ghiIi-f1v9vZ^%SGKJaexZarP? z_hrH3oTBjlA*oG`hfk{GfvvilI8f1w@RGifNOw5!FdrT%5u*z^y%Yc?a2)foqZIbN zkt6CKO(Mf^5e8#(BIoalZl)Ue=Fhs*#j+Z?9U5hmcGf>QU>j}GIS`(+-=4Yj{j!|~T53MY490$&Wm>EK4_hvSHqBQ83ddD-vvmd1POLZT(cK_$ z`%Ajb)b{fS4IqC71PQ_Z=g1!JaSl0`673sV6tQl>$C(MU-K8G#cRJ=SoIig;K|ulL z?nZon_a5a62Hx&koQZubtY{lCCrWdwB*F62l~$Z2}y$Z`zC350qYlG&_`sAj!gCt^=? z$Jc>+yIwv*JmfQ`r1F`LW{Bz*=SA3Wc=SMs`9fE=rTR2uD{Kbk4h-Z{>-CekZkpcL zANqw7)(w~p*&+}RoYE7SL@m9)=UdC3MT-}Q;d-ICOpdlWdGh3yD%)GSLXl_BHa!_5 zu65{;5()6R#<*=it^nV<#Ny)T7hJ#S9vyUBnH=!8&sAta>)EEx0`;IX`O8^HreYC- zB*x)3Jri-z!o%wb-UQn7mv#i?1{9jRb5r^B(JW*yMk#Tc58hGtNQ6N%Gu&&LD6JR# z&on14<6@96KRIuZycj7__d=wBehGabV(#r*+f~_P{@`GMYTW#}7<#}>D@o=QZS&EHU5d-gDf7!I9 z2R`|rdatAmaW*@R5AuADqA(R?gyj8zg`WVVY)P5Vp`(MDt_#g`8IYok-+J3r1BWOH zBp}ZlgW|fk=z-U1!`#Y~K<#EV8hC6Au;}^qBw^}Jjq3h2H= zkzUtwAFl|^phtta&q@->TPoVSm0$?2`JdDBzKSQS~6_>!n@z!G1Hhse>Pe+pLqpNE2Dx;+hH<2oJzbGHW z#$<)C6Tilq(s>0}1pAzr!8mCx7NF9;H@x11U4gYdcpWFK%PtrWc(PgS`+d)o6qkgB zI%4izKzhy|yFgG@g5FhX65#ZZT(F@0a~G!K{69NhhTC2y{EwE`{is$lmO=8}h?ZbJ zu+Y%ZP>D;;kbdYu^Gymy>LIS&HtigSi0;Cadzu>$%}1>43Z2Kg`u`=odbtU=cy+A6 z8a?qcaYxh4&gQPOtS|7c0L3_Ju(A0C>3iCy8D^}yOby}!Tsp2@B4g~OqD|xlXupIJ)gmoaB9T?#?oLjp55DBN4W`bspwA%TnG z`Rnry`&(7@Smpla>@1g``-DP_R71i~vx$XUKDTldWCkkC$p!wYRS6IdRG%YR<-?YK1~zBPGy+m#tdWs?LJv; znJggi;`#;=IHQc29sYA@WseZpWx`L{X+_*hexYkJiC-l0^aQ1E4*U_AeI#w_pEifQ z{7rqk%f3fRoY0S~I4pbh<~7@^9%@T?UlrFr%#Zr%q<-${NMJ6&)2F1B3?aT}_wdi16XKmM4F4A?ZSq&`so%zs#m)HuQ0x=APE z{B9VDKY4dGs&#+yw$=S#O_efYOLBe3#$47nyj{_hE;Z_9&ku%A#uVM_O3C7+SVZjuci!MUZb3joUPQ(UlCz(97&&_IRv?z(Ju z#Un@lw-!}g=q$MhuzadGZ$FTTFR11ec{FHDfW=CD$6}Y=1MExayCptMf12|2-s?B=o@ z9SNCHvOcd~s5W=HuFu+t&BnIdQ}=wWEH!%a?1V*T#(uv_W#wD1RU`k1)IJN?etN90 z&w{S7{?8T3Sw?pGHP3gOicdC4Y-XY*Ve>=KK0?CWo8nCcQZ6Df4XBb_wrmE25^}QH z_*w6rlXvs+mr1Mnk>}6=n}uk7gBu?ItU^OkQ+95Jqu<6hH!0?mL^=KK2}D_YKp*u@WmbKcFSJ zTD>!O6ivY+OA;W&3GWI|p)uoD?G~%5x#@h2jeCdPB?xE?Y>yPUsHCBidssXFdC52*MKRKmkQAM(l`AK4sqTBd74;k%NSS69vTa{~F0Pv4=@2%j>(A=g z{GU^+W?3uiZ{QjKGJ;$5D;Qm+^Hc8z{%-EOxL$suO2$4>`^9zf?gRCvf%?u1ueg}T zS1_lsYP~T?f8=vvHac*ytHD66rr5*Zp5Q5B@eSA7RV7UUq};;0LHh8Ysb6~!-dh9W zgoegZYPKF!n}NeWd1ynwOeL5hyN9)zd2k#=1e zJjOi(Yd~TDyCpVDrsD*yNb2~E#z$w>)}6$?{PgNZNcNwMT6;|(bg2-wp|;Vgi;a#x z36G{xIDH|gzLdQTatd?<^6$H}bwYnwS}VO4(%)j7=|R~=jD91@H^T>iG%{nnVD#b+l5T_gU($mCHBu{Z>?xi( zbA~Ye108Sqk+`i!299XP%^(A)Te_LHOQw^A3}YTwb*8sWr%!Lap zZM}%$KxGb5bw~koJ7>s)(gRZSX#=A}E9Fh@;?~1*(}MrvX>O|yT1AOzjg?V6!Sm*9 zz35H-vCb>?_v(hYWvhLeonVkL?L;|i4euBI=xA2m&oj_?25W2<=B08 z2Di=bKB<>q+e~qkdnw|yXqh-;(zLkM_Ng2UbHR9B%w*~w0w-m`8>aw^Re6)N{P%_# z+T=c=$S!!t+l3k7+%;KagmzDsIAERY!(nL4>rE1UR3iru9{~vkb?fNNHu!Zl-i3eW zj|YR1vrdHSuH#_Hj=+YT8xy3Cr6#=Ql4TDD7(a=A?rl21kb0J>YUM=tR(Ng@QK(by$0ja)-s`jq!No|qI=3N_&+-Q z)8k90pu{YfLo*o(t44}9nHPT0CvwMYC#UAkt!^pIoTaofQ)mJQ17}-S#1_-Uo`~tI zFP6%I`Mg(X*lN^ z&ZJxLb6?mTb8xI-d|ojUlJc<=Ui3VW8EAhL`P^AO%8)bserVWT1Zrhv`R~pBmv)vn zk3egN2?5NYV^B?4BvO5#LR^$$VvUfTR|+mnSpNlx0HKkF>D@Ad#@CvX)RADE}$5>eZ#qIk9CoU%L;L%MJg2X_%`QG0U^B9o~ad$AW;P zx$XV1H#_T_!{`~zrzl&D+7P-}b$A0AKRh%a#b*~uGqG9T-glWpYzPqkc!d=hRPi8i zsHgLGJE0x+ciT9wI*2=esvkb6qMFQ8`rJ@_jGZ^#OWay_= z?X$*>)|?_6O$(ld_l>$BvdMCH-tlDnOHs_wp)TxbFmCMFC&=dSpq??k@627B=}?G5 z)2XXzH4&n0a0(E@i+YS3GZue5J5!_+RB%G>AEB#S=}Tt{fI4J1$K%vF6)o| zRGD`3rmczn*fV!eyk8z7ou`*9^kUJwf%>R=@kjc%Sm0FJM>sfEjWv{|L9vvK!4csY z-aun|wVCFcH9Taugv#?OB=>b@LRA9)7VTObI@|QqH|y#5I-YUz@9MXG?Z1#(`e~_3Fwc3n$ zeGZ@8rqQ%%5rFR320pcC9j6_~Uw7ck=CD%A6~kd3Lay zJVh5^+4X(z%WGrB+*|mGu$TTwu=^jlf5Z7kWAk!HIT^*KlppGhm4Vj1w-yPcoE)pO z{*i=fnS6mmTgfk#alQ>dL~q;{EDEVDc)+I;48UW|mhOweOK0i|017>ULE%p#)t61z z`9yFWZ;s?>*{S45@)4n#kqKnN*(|! z!V^f7LRKvDgLo)zzK_GLEA)pUY>nW#!{uM{el7YlbRiO>YeFqdXUrg^1&;W zb7ulyUlhO*p0}=jV{^elZ_S8VQEklL)3_zPD*CqsPEPrEr9!`5ij}d^=!e6p6oRIU7%0w)uW`!bR)3 zY#>;^_W>IS{_&oSMXQF2@s64d-6uT3<4&j@%=z z_N7_ogg~r`7uTj+%caIF;*j`!W3j!}^dv5hnf7*m1>x^3K1c9%_bGatI;BnR{8Rab z)s5qp8ZDY=bi4r8AOG$DD{3&5h?8k_)a7GHNXiOUHfkvSUg99^d%d@GNP7?oyjG=> zTUOsEtDDO_YQCw?i{9(DV$dAK z8Ehggp|$;NtMN7EMpJ7^cxp53ek+?G1C0HcmUO>h^u;x$chBGbvhu{M9wVh!owtKe ze|VSa0rsD8($0@joBWPSbsA{97AQBdqliYyVq`As;r^XD+R3^&pJhRGZ?A` z+ICC#i?ne0SohATms{EqU^5`hd@=`*Z$+;uXGqQL$bc1cj!kuasC%1$Z)qb%Au{S) zi;!K&zUqwm1}6R6u;Q4zeowLElHzH>SJV{WExFz1cxiU8<_jq`UdEkaPbJ>#-KR9K zuGP{=to*E+d}eQgT=K7+nHg4w8{+&jGcJbbY_t>h{n&e~5-YJsJL9Ry%0y(|B&X zUGC>ooQ#!)*1cEsmgdM+=9wC$Z+3W^y6!%aYCGSB%yC>azrOJZ<`vjY6$w6cy43gL z?=vo89=G1w4*p(oyFtriSIw{JH+jR?>rWkA{J^~HTA#3!#mv#nE>}W%w{C71xaT%? z+4l!~fpmoe9z+*JLPbTG40>Z+>y86K@vu;B(bNK(8QSBl!hGCXEtk+b-4MsMR<`y{ z!CF1|mQ;&Nhw zan_@bY*`MsU2E2j7G@rF&zG2~9@jqh@}>SjJ9s2;m;?QGbDhwXDgMvb*T(V+zWv|2 zdto}poWZ&=lfOKCpR)aAKzFnM_D#)!Gow0xMyBzu)nE)At|$|BwvhYQ{*DLyBOGE0 zeI|kRla3ABJ@OGa7L@L7`edfp1EoUPloo2YY$h}kIOX;J39Tddrzjcx5 zY9BjVP-zxpD&vs z;%B39cc<@NO=kP5O4V$v%0?)Y!SP=W5<1g&VVQ>@kcYGjJ?IQvMj6RmH^rv@p_h{-Qdg zc$h;v>;NX}MIZs<`+4#tueP&M*ej<^14EytL^ZR2SER)&jwHo1v)~Nl&f&vclXhfG zSnt_00pPWQ%g6QPN%-ky{CJ?g9p#%gY$!k?l)AfMTy}TKJPNjV>}nQaE4?V4H|Mh4 z+y2*IW-2`ru7B4*b>sC`izj{?`Hlrob~s+>(3Uns>~PckF2x5%@$=k&u~{Znie{f1 z*H}h9VVo<`eLd3bsc~t@*!L@r> ziP7d)ms&6sk}M{l`ttVLOAEy-PJXvdSO>?!35!J&-x#NfaI7>x4t5~E(uncstGy+C z#g6T8^8*G6L5D~*rhWk{z=U2$+b^+-8!09sQ$a-Zc$NDF4!s56tMswwZAiO;y$NH%Y}d6D>D+d z>_m|^h#n?YOHENbk350>^os<&nePwP?l;bIlmsBK?4U{aOr#xux7I%I;Xlf1hdN}u zVmB~WuKlzEx7m-48wai2*A~VrI zZUd8=Bgms}q<`>zmJ`vcp;$fi?fr2)GIJcfZr#Omx?XSDM7XT*VryOyLHyy?R0E=v z_-qra5#}2I{j=fW$8b6iv;)B5V2~H2`}RxAy~qppu-AUeOS;nLKgu>N{PpoVGI!>R z?Vc23rD6h;Jd%!(h7=HD;7pxha&IFi9ketn)AwziHEUMG@j0=%XkG4FbYA_@tZx1@ zQh9nK{cvYZ0|UXAFJHE#cn`Er#i#xF_F;S7?6m?AbM2Q4VMc+3joog!sc~rTHAHST zdvZlHkSM3fBPODc^NhO2PmYb5iJlJ(yZidm5`<8P{g;|5ANcfwgY;Cv;wysD(>8A{ ziTw28q4LwGPg_)hkgPN{HJyL^4C^=9cX{OMC8Df4r|L30!>+~`rnbP|MA1-u1__aI zbg7g)|2jcwLzY}vOUwS8;yxkjW}&AY2i;6jaq*MI73woXPR0d65pxxe%jsFV%m;?y z7eEQ-&0Dlc0H+$+q#6IZ9vu1)^j%Fc1on?Oh8{B*)NR7^j41rsQF9F5OUAtiSU8?k zMK|bGXN;0>eU%QKYP{#ax3~v`mtgkf_aPN*)Fwc@DSfK6f$&lrUtG~tIn*Ki!(-Xw z{-*KkW7N|13&Cu70(13aSS=0cxbTb{w+i;ljxZc;##&sZ;P9Bu31|pMQ6X~eOgFF2 zl;NV;%=-~v1ZWIgF!E?dUvSOFjkkX+bar-r5x;IKah<76OY9n?Id*UFaNzN&koPDc ze{NaBTZfi1TdYz^N<)U4>Du>X6H&4-KQ}GiwkpZ?>m?XV17$bx$N`I}89KQvk6u|c z>Ga}Vys@fPT)&BuOC$AIA?#8je=XpNm*F}7_QAn5#P>SVp@5N{+UC9k~vJs0gLJ_ZGV?fa`g zA0U+;qXB7T6&oH6lQ&;$h;z#@V5Dhl*$#4p8%&T1B318uYay&;C;Y(mdoXxDmm%a! zfbS9Mk7Hs^gHSadEiWRzJ0h1rFlj)Ep9eC~wt4$s&`;a{OU>o;pH;={bNiqlR%duMiHXJ^&VsobOg56qL! z%P3*s08B6Sg0chOEu(o3JMmlI+}?2#M+!YPH1+N(pVeHxN_K&KqAVA8Rg*4rlNMEf z@Myn9cEpI_$4DOofh3vZq75@iK!BXp?m5(p{rUY<2X|t7_&&oYG<1?3I)BuPBP^I1 z<6he;ig^z@p%;>Tb}JKV(=u3`E#P<&8=CeE==?<-g=9XBpb_zXxhZ;^#4o_D8+{g~ z1W>u&^B6rK&=^*Ec&JTRvI|1T*n%uX>mEHUTMpVAnJ7}X7^o|{W4i|XyNL7xwf+8& z2oTty;-rKkbO-n}`PL(;AlUrt?R{Wlk&3{vzjGJLTPz*(zK)u#uNk!+zTS5G?RguZKM!21Tt06A z^9I2zEG!Hko71%KyZksgPj6Q5<>^nMs#%*Xl{tM%X6oFchNkcmn@KpUTlW(lenDwJ zDVv(wQ*t2i2kY!eWJRiD<0_(|(9R}@JkZsJD2Ct__#o$^9NDO6{P^*e=>KQ`{`QbM zZ?Q*y?IlqjMF5(x5g}9JWuguXJ-4RgbKr~|>((?DF1x(LFfMm2ouAo&Tt%s8D_$>= z2Vp}F^@HrUoZ%znXDF=hrttVr-}dN%IU6so1d5EqeuX*m4f-%5FCxRlf5j$ryEG@~ zpGiY3{H@GHUiBCrNDx!evGPQ7Ll_olczTmyo<}hSuQc!c;yrg98=t;2o6Xv^>F&~b z#5&3gm=K<`+`W4i746@>hedG191Kqc8z7#0A}cCEx3#N(JOhh4^WEcp)e9)0Ac#T1 z6`6^43#VeNw)#%<;okcI2brj z{Lm7dK_o8}5G?~F{se#!j0B1N?#eDGj$)_}5B6jYe2nfYni*maRU{QB*%dQG)O{h) z8zdbU?v!B;uVZ?-1Q>z`pxRN}7;s`auEoZhn^WQYzFIM# z1b>xEilxfgvu8^<)Go_4>!B2$wn(e)4eTgQ9~?+5Y0LV*iB!+b+66Wv8p>3j*eMR9 zB+`bsoTlld|AC&Ad$SP3;D~%);KV{7j!}2qd#tZrgA^`j;G+miwgy1xQOAG{-`uwN z`9;;qFk9(|gLCUwO3EmK0zOaz2EmKZ2hcUXT{WH4KmJ;htjgbxltB+p((dl=%hm1s zpJ7XlV{~swWhR4KL(hQ=fv3@kBQG>(2)}9nx9;aKYOPgHoo)O}+vLm3SXe)zpRSGV z#V;get*9&~$BW|O89;3eRFHK5$>E{8Q6HJmvkTeJ9D?NeD%G2Z%kZoSYWYdIbC6M@P-8 zj}8{hq(jQPB{D~JCtOl*ZqMW+$ekgvU;!VxhZodi!@`at6ETMeyP}8L|DXz-6_vYZ zdg}P8Q}-V3)Y9^+$}TB)O$S9_dmsO5ejIN$FD@VbpgN=_u=+B*u_Y8(>f`?_Z_NJj zfHmzt8W@CgCID_W8doa``XK58Md#$RGuVr)1d9v=3#Jh();`lNDg2oI%54MZVJ}r0 zuf5L~`a*;^TFrQ679|mL6%X0%%>1XXUS}Hoq38JO)mgSjZk|kNt5V!cvE95^><2e9 zw!hr<;(SL>-|zEZ`2Ha)*y`^nv*#iXwR!OwY5TGby1Hp4b_TW|9;DY6va-7Ft!BU5 z7h+BesAYj4Qbp{uIF#QzYMzc+_rI|KifrEtyl^r&Br7jpVQU6<#rqpgtHrnW%!r}M z3&~q6zGda`uhVTQOlk-J774e<8|C5jK+_P}CGXVv;DRoSKWmWvvZ{@TE`HuwE5 zx%0nqY{~kK(15+h@QJR(8Bc^NBEe&?MHbRSeeq)R%^eRj`h?ioL!sOJ2DhgCvFg7* zUCnAeyPbGoV7R>eSPx#>)>|&dzNc}t=lPADdH3enqUAa|%#*?L zoJVjg2$jFMipu03Nn zbUe+_PV<19!(%jH z>EI$$5W*1ufK;pvQ3R#wC#TAn)XKxW0?S*HT8qz4g=J>q6o=tg9;L1LS@(x#DHPaw z88qWS_V>Mhxm4r-5aX6uRhN$p2IXTUJ;Ep4WH4wp4rLwW*Nd>Yl*pTSjuv?M^jBaSDS9*^==5wzcC zIbEH)6+9(;8DzRkH@N-yq(}ARMD|r6k>=?D=C>_wH#`53(BM^MJ1ju9JegM2!qv1><8yND7LjBQLb}qp5O^r z2uLu%MSX&$ywK)_MDzg7<05Q{M36mUx!Zia!ix0!;H33`^X+%?>z`@8%%u#6>Tt^ZL0X^J6m0)rpd{ z_sa{3v2OQPA+rq&>iWhX)A<4JB*CkWx#)|Q!IriVq_vCXs-FFKK18o#?S-Pn*)Q{m zu~CRpKHEvQd$Rso_kgRpY%8Od>tsnobK79{VdpxCto0c(&H(I~!>+qn@%Z35Tq%Yi z>1^D{M_H%flqv4b5iQN6j+@EW7KKHdd3LQPoe35#`C$hGSKwqU?{64Irkp?pR>7ew z*8Sso6SKLjMDaso72Xa@*3lh`ox#rc7O9QYW^E8?1>+p%_e>#O5Z3 z9R^%_2Wx~r^paQK`D`y7T?`*|jGL2F9z&aCxqm)OzL-$ng`N#g`R-=u32^kh{f1A? zP8IFu*qsVl6`QqDC4v_<0SPWiVH$4UqX<{=@JogygMfHYAXE)Ku?zP5jT z6&`X2nklw&%y}pJT+4rVEEYwJ0C`lv6Z1RgXYz0!8EB0%F`mJGbj|5L&0q$Tq^1d> z5XT`FDNGy~clK0>t-rKpW?s#chu?l@TyeJQ!dC3>X7@xHLMO zl9&rM=!;Sccw}HP6E_-ZW@Gu%?^-B8Zw)Dcr@oXp3*bSyqlq*sUtsmx5i5+aXqtmX zi6>&P76_BPqep4!-5c1m=)CRWMD`L2z!F04LK8sf^-&psn&eJZmiAQ*WQ3`>YzMhO zPangzRhKXSJqyF3NSQ%YDwm37fk-|fcgV2+ zFiuZ!L}{pH3HpKsS_IS{x#5B+699!3^U4iOjsA9Tl9=h)3)}Q8KzZAw16kFojtz8f z!1F8yx?53+u0RSrCb|n1`!0kJvLs?IPBcv=XZ_8^hj1yJ6s5?08#wxO#NCzb9pZ?p z*~D0KWJyHw79t|i1UWivpy+pm5jZ0k)k!DNRtwC93knW?h9}u}x;1p-q;(@pBJwA8 z+GXV50QipLb%2C1Yi-&<;hALCld5aRC+FvD*}c0rmfi&&U6>5IR)B^`LJ5ZlK(7-y zbB1fjKYL74u!tK{)$21bnfVFtt`=-s&@w>H%?Kd+@Q)_ceuKLur0{bb5RGQ6W0IEY#eI~p2u&%28(Fy4K^1rs zJL#PW>{iLi*zitjlX1V#P>`tV%hF_aB8|mzY*L>uOLp5m*KF+^5sXELMbPI1<-XkG zJ43t&^F&Ion4IY;%dDEt`ezTY1X6MAqoGvYr_b7HMN~!sxuN!el94 zD5O50Zrwb8;k6qh6c!oS*i$8f7a4$D{W88XqW(Ux_-kJOM92q?&t%{|k5WaGr2kH-W1Tt|KO-&U(n$Df#ZqvF z`C|Vwi09=A@SHptaM7e^X!v~Nk~RwTR3w6w5pQ32#YBMRk3dtH4;5lgT!OzL^W}xD zs6n2(U;ZfTAM#&)~@jJ-lTTs z_C5bq+l{VlmN{(F7co~e)Nm!YAlH~CuA`4PYxPSutTNN!m?kC2v6D-0W#4EHjlSy< zdL_26BA2xP)QQ$yrM*YH_Iuld?Dge67cb8FymzhhrQeUG6BJzeJeS_n>3-0=dCu&o zGiPyej2g4D?(Q6}i+~MvCT3aukBqO1Zhcy*)H=f@o6(q1*w5bT>kbaq4bvloNz^QGLU|j&N%_DCtr+> zjD-5W*N5PdFLh+?Hm&$*P0nDL0zSh?nlj(h2=q&G(Gw(EyXtOL_C`O23^n@w@R#Hi zY6RbPH@PJ=QGB_dequh->lx|l%*%|^ufXl0uLI}6+KZD#MZrPm)sbTzKS3w>F~q=H zahxyXk=evYW)4afd3mGYj3q-4`Y1sow3lBL0!xR25=7XfI-p>+{# z%*S4fGnF0c1Rd7R5+KBdd{%l`UO8G?vu6dX6yf?EQ7-a|FBb)0TX7m#2x8i&gpcC< zFvmms?{l^&+rQ6+)*$8)FY8G!Y6|5o1aEjZ&cuEMhdWTX1Kk}3iBy*N@Nzi6^5Li2Z{*-@hL9#B?nHnhVaMhAL zzGcr#&`l=wkJjeRldkE{#Rn_kW-P6(O+K_?^rbc$lb)WQlZ1tXfIbseY}AIvSiM`L z|5R)!hPb#m;aoX`-6yLCuqpCzUQASvz|pgUWL6xn84==18AI z2da%L!P9WKBA4Le=9a`DCX{|gZCHRQ1YN&>|NaG_aZCkjz$PIFNW=+$EXq}~krz4^ zqnymo&zAtuNK|$ZJmCz@RH|XpF)9kB^MoiIXn!+d(Rp4BfyH~o87ZRR)@e>DBarYV z=P$tu0VQNJ>N=7zaGw-fsM!SU9uWZvvN=_3~Kv~g9sb*n06K86l z%0=&CltXEBnF4pfxpaPVY2+f6tOOm8 zQ4_xty3;^tt#kSjAV|d&^qY9SP$ea7yE_|M>TGL^$LJ~n2I2;h<}{jYtE7rIPOm}+ zgdBI4z*wqhag=f}h>8pyE}3nTJj&SkoZBgEN7A6faVQ^8qd;-O3^hTtMUnC4L7X|R zcRG7v?!j0x8Zj3vs>*p>%U9Mgi3n&vzttBaiP{p{*;F z>3}Ax$0IR7`kb-`37Z=fS%>ERaStO69^t&Pf1h=BwzEuVpmJX~WijF9!Dnprbk9!rAdV zMHFrQC<8(NRmHKvC(8?I2+e(gB2sp=2uE&xy*L2@h}h_vN!k-92dZKcf1%)0+#6+uKTg$qIqd z02yj#DNEnPuX^FbBGJ<;&Y>vfg_by@HZXlqcl~-UhC#AXnVp=31moz@qj2jkg!otQ zk;g$sE=pGh14V>l?F}+l0LO!dTw#B-?p@A+Q_I%(htOwBFpxd)`DMXPJ6kx>B9i-~ zJO|_HX<6C)OBLss(JaA|xL4FQKtw6AyQv!~+mF_f(?GWsut1@mo?ZTvjcjVc>LVVV zUqC<{m?9TRptErg4bQSI8U=@Q{>|>6i@~56gs0d@#-8mT;_mr0Phr(8G z@9H&cf@+py*))kZ4?MGYhR$z8dwX=-K)_c3NZemUSlSo^eRYo?JlN+MsyDFN+y?to zj~PX*0MvL4mTI82jsLKRb-orEMwDz)#%2#NT|jwY7vKzPNH?8xs)+$0-gO9O?ySl$IRsZHJ=ffWeJCJ~M9FW1&9Dm96b1ndPP+1b)A zJC5!GrvI#F-CQBl-od!0kT#*3$rbWxb7FT|PQ?Fc_1%rzK~pqMwGA87-@azoL@a4s z>Zs6R=cK`ldqh^WMe^SEIHFga+jFwxc-z#P@fGva#vJx4cUco~H5pcX1%MdGrkk9` z)imK%JcF#p52_rRB?p^7F2^NrcGkf#HD0qsvj6s9i@H&wipK$;{bX_-!eK0D4CD!d}t^#0L%-POMd?yp`%>{^7%RKHOu9QC)^7|2+t&%6-%BgpGrC;NikocPrbf!C z=M=@d9SM8X`93brGNr=(_Rr6iM?HRgI*M_m=k|NIMsAe|TSG{+({`Xbkb<8=DSLN$ z@Y!GT(|;aBO`BLei0@pBkC&8{Eot&xciZ$nC+va^kR`0x#DiFM%d?9FKdre)x|#efQI+Po{0;`4Jl%OC>AXaX8#ngv7-| zK)tzwqi%t_|3YTy4ZDh zrDSAyG&D4X+Js9IF$bXIX&(t*;Q%=mx-nxjK9yVZc>GD@K!)B@|*^Ad#tYiuTj?Krxafm>55 zEU4$hhYzXfqfmB09^Qa`2L}fUrq3@dw7_#oFiB4;UP#_HfF+tqiI3xS^HuE$WP$}Z z77er|%N9wqOQGf6Jp;~9e7TwPbv8s>FO={=8cfshK*QER)-xmR2y0E~pc2HzCvnR< z0o_|NeZ$d^!(d~Q_o-zq0tgiqU5^__?Ls4kQV|-Gw<%U#qbAL&iw21z`!!m@AX+h6 zXa#`FPow02`s&qWETaOLAwK7fV7$PvE-Em|KnS4%bEC*^+;eE4jj*?xL8NqqGNS$@ z_6?0MgKGi_qk-njoV7Z}N^(0IF`|?v`8h%#L=P8?2S(aaZ(@T9FE2;MbIWKv0}1DB1dyXd+(+qJ3MCErj7Z-9LW$JjJJox59$Km_*!lQ<%Mh}c7Y`2Zqnrppwr@yCh|1a6*kCXp zb~irt2MO}ZvXLS}I-j(pFpvP*k4EX`8>UuMoz;K@NCOp8&Be)}-Gg7T^lurtJ_dCVni)qQKJ4{fhMEcomUBbWnv_?P zRz24FIclO?A0E;iOwJ2Sb6@;JE#__%uyB0wnMa*Ts(AVN07y+qWvX!gI6C9qT2_{( zH@4~9w_p@y(JtTW9*y6*bLW^_K|j=Hd^@7fabo`V{QY0HXNEd;2*2#>SnYq`MPiMp z^P|Pgu%&ujv#u+1Uo{_Qy@0EDv7n3CdRWV=fn$lLin4Z!>~&)KAj8o}t}+p=0BaRH zkaQ?xQf~gxx2xGLVE4RTnE8Q`qI_dj%F*OO@kh)wGU7!NunQ31h%Ya7@g6^Z+z%GO zjQ@TEHO)*Pdrm`8_1BGGG;NEFAWfD8-nS5JlsUKV{8($2`FJx6Ix`qly>`tuM(rT+ z)U$iamr|VxTd5?ItVt}YW&ZCuAe@gWU&a_%oGwO7K7FTc5Eot5!!3!I4{a~ z*vyUFxL)oxi{d=R8~4Jw5`BHb^r=t*!3Q2AM{{ykP%uLxwbYgkW*a=eh1MHv%=3Xl z0v6u(sn>(7-Vwr|5wM9!uX6Z!{0W^kJg+nDYebp&Ox&5WT^u^c@HvZfS!j=z|03M& zDY0eKQDnGvw~TFZR`vAqig|QXf3A50!%}pqwF}DwQCBJOhvO15p@VVEQ5tfv*!uaT z%z2Tzv!talXhzj}1h{xk5N_tFw9nbu*_*Gd;tn^rK%7Ro(RQjeQnx#HndJAH#q}1< zRc7F78vMd4!x%nT3^uqB{JOX27pT~?! z37Al@-)zH$%TM2{5?V}i-s@rd$YTm{08DQ?)S1M? zn5BI_o-Os0-68sNNo95QRFa7z5kP@=Y}6ROm9Y3nB0wKYO&IL5m9mj)|Ao{)04azP z-qlh(*{K%qrSh`IDAW-In4#kmV3d5YNp{_waIxUepEt5u9;z2KF-=KAL17Zf=08?k z3PCeo1oTrXU#QsyBvxlW`v5+{MDvln;F$_$J_6uIJVH!6@`Iv?FnTwf=*&2E``(YR zJIsmW{CbJ4cYLl=x={J(u&a(Z%WXeM*9VPUEHbgCKd)aK5OIk3-#DZ$8zoo5KZ7We zyeewv@@o^W*{Y2^@#C7hNo*W@ihp5jZ(}`qXvIcU0vOjWx7BI zJC1ysFzclDFPy#m3wpInN^O_e)IVN^vZbkEO4!l-!Z8nC8?d2r-D!cAD$CtE;S%`n zqv&bvW0)%P7fKH5z)T%xc!&?jW<)2i5ZBrHsdDDY=80nJx#Tc~EvPcGNwh4->_#Wx zL-9OT^{V5%CbtVnjwUz(AYc2EB#(%xI2krxR5?n0P6!TG>+1^;G2owHo2!zimr}jd zU6B==WcWTessR3>H-Q5v7d~F*i>_Cof6YuY^BUkc9NZMj=3M|b?LI$R^i>RV z@}HEH^DJbWC_U41f5W~!c{Tgn_#cADL-cYN9%D~=diit#MX%&R84Y1G-n9Q~o3AV((K|cCC;})#DQLaD1{e=+0oej&OJ8}oiTQbcd8g3!` z*b6G+VBX`%QkaOw6Y4d4CT>fm=p6t+1Si=I;UKR>8|&0`#S*cv(;*^AP4 zji^&7yux?AxU3;Wv{Wc(NLNGbjsLLsLlVN=qUoRE*+}ojz(fUEERYR@+om-iFCs;u zDMJ6eN)9%ALw63a@INTn9O{=xyWnERtXQ$)f2@}Vj67z(>i9scR21C>5ZQF%(3hG! zx7bn+y;bCV1%Nb2mif<5u54c(B0L_#}$!Dq!UUMLRdV%8T5tP&+&+ zEF1%*@l#duRJ6DE9vs51I6c+O4fPvGXIGa7v@}G}F^)l>av7lwJQF7Pub-_6l?kaC z6RN(bKN`*2Vp`=-M_xfV&XMJ%=#m^ca%A;T779;5nL21vfIZJIF79yf8=ecCw{%D? z1?e3L9S{WIuD6xE{!q#nAeOH0I}fk|WR6KfLIFSNc4s0o7S(_0_zP7q5WhFj`Y;w*r< zvm=!C_`p2Q;IU)J67wNxYYLydynNRE`{)0(eKX7QD}~T)rktGICYO|*Ihf}A;;PQX z@ULjCk*`T<_|GK_b4(Q=17iES?Bf`ofq5n~unJuBof!;- zx@kbT72ul*cSmuO$WccP3nUJc&NcCVD!Tfq+3hdbXOYnBU zUa_wzJ*l+HFN#xj&d$b=diGr_yT?*i` zPHl$p0YnxVICT{io+)i5(4MirBxU)-O-I^R9Om6;l3};k={MdY{AiCOxorI!1!-*Ri^~0HGM=?kmhzmTPo4b;VGr;*2BFS8pD9>Q=I=wZKseLm(WWDd{XSh_q zTSm`Vi}1v|B*SC{3=drbpGy5iL0<`TP`9`}ZX$I?x;6oY$Fa;zWll~`0+`WLeq=NC zs}%3F;1{-6DoUKcF~`&f6-kOZC)y&&y*z!PU<@~RBOv4Jp|H&d zk?tLHEf2f`y_6@|elVpRO|G~|(BsgiXQVBjcx#8^?aeazs-pvHGy2BG;$#q z0U-Z3?C+=>NNqCQvpSMZ>=@;F7Pj(aH;VRGC(JI7Khd@zWP{z*3F0Wk?jR)r2E15% ziO%ocx#U|r&aN%t1NhFGb?YFzjVKs6OASyU=L!1px<@YAr#i{6-j?8my|dxQcd2jR zw`WDgX(i*9Z#z2Ba>msz5s^Q>o-=sbw3X)rb-l zHHVAQcZ(*+*68AkzRw-d$W>pqj12M4_X?dtKM;CPV3h3wUj9xAIX82g)ndfH-&*GX z^G)p2u=bFGx{!eVYB65F6o=f0j;_5^`SwlE#I82?(l!=wc0Yp^`EPCxNj>t1i5-c1 z{RrE3hhOIYv}@O=1y575Wi1AR`1xpZe;?nf5cF)4rIk_B*V< zKPdCyw9o(CZRy-rXI1Dp?bG%nGQt4HJWOm4#3ASzGZ+nkEt)d$C>RVHUx|cJLpE|Z zsm5@;QWpyR3Q179u`(&9QoO@aJ_#qdp!mc8M{~k`dd@*t*FtzRzJOHj9$?b{xC<~u z4uO>W3Lu8r?U$AvR)Q$_G;#VtE?Ehm3hEe8ZyX>RnCu6_9-H01zA10q+k*>>cRSv} ze|>s&$G6I>dFZiKPb&WOe|&nLkjLqY2Ps9eZ2U3O6crhRLCFVl6G;QvNEI~aREY-& zy_{yOSTfEBp-lX8Az;wN%D++_J@wht29xl;>MOdIFPd>DcR|;7 z;Cr-!1D9A2M)gc^h& z*$G7^ViY02e4VmW9P3PV^h7Z`r4?YD;M}4{5lQ2b(jg7vrIHeo7Kw;XGarHDJ}G0A zpAI}6XX@4Y!?wP0e)4XU?`lfF6@)T|Ei|}GmxYDe{YgF68hc1}0rMQb{}?4RgUC%? zn84Wx51ymzXXgQG>BL{JKnF+J0utYt?3Z!UZikO!-Lc)?45q6rX#oDZ7Won<2qP2q z_4OORe-A+jv>+I3)S-cN$ZUwJSJCtEz-M3RTw2TDIK6P)ltn4@#z3ho?3k(R@ zjZWBr$(FTJVLY7&lyCh=uKa9AyojOTu8^DFp=GyLSZeg{6F$|PDQHxncg|G29blVKrpO|3v}Zc@Ox#}RBEOX{VBH$8 z_p^8$bKjxE2wEY!?H@uS#eF4+Pl3j8xtVTi)2E*Ex>H%lOOctIy`Em$R(W?I%3MCs z6$;~6=m!zs3HZ`QRhQB9+8`=8NvsqYs_U(joO)w((uQy156&JQh@I~FuHlPsJuiQAV;A?5 z3FGgnJ^qnL4n5b0c?DCVO|Cdk}e%QF;iJwyelJ?UGM8Z0Vqr=7T}f7Q0tHYgwSzhBZf%^ zPf&uL0-6^GI;Qa*^P_Zi>uDQ>vUln^4EUTnXnRt!fUM@XFc9c^=<~|9EUtw|^`1lQl9&8f9t>g8{9XvdCh)WHd=ZNo4Q6_qdOj>iqq#>w5gI$9+HU{_1-=K>Pa1Wg{m~LRQ?<{5;D@X(F!x|XaZD%cd<>`{6rBD^@|9*gsg|h# zXK3pzhpp_j+$Y)zwt`*8syHvg-hNG4y+l*KeH@q`kLWe}79@91q>j$+xX&$zO6O0M zmdPx3CHYA^NBG<+4riY)gU+HbhxS>b!-``f^=T9SW7fJI7M3Yms-O4GLu^I3&Q8py zCAV!{9AqLOh;nr;_Y>uHX1>V=N0T70c!oIZ@ z;Xi27(5v!<6#xQy(S!m^fkBv$oF-T7aQ!zr1El`CeyCg(074`1mk`JZhvKDRwv$B2 z;{x@`=+3clzlyqZba7p>Y1#-2BlDJ3+wI=gE|zx!W%!rg1cuu<-SbxD(!v(6rVp-p zi0aLhCuDyKmg~`srr1q$u2c){p&{ib6q%#)iT`c8_Dpa%p{O%PEH+9$xj(X@}a4aols&ILpbhMkXgjQAFBjS`2w)C{!(j?0l;EzYcu zBjS$Liqjm|W`Y`F? znm-mHufKgg2azuJ+x#%Wf;^zvLK{eLF!U}iFLU>BAf`Z2k2`B=pj>8Sg;4v)kG9oI z=lqH@&5GKBo=%nQ)0c5d&rn}qTz<$CcVe1TV)y(E-9tE(TyzcW-cX-uGdv=&XEZGykwb`-01N+0?sdpb1B{ z(KSUmiPB26;R|$@NnKMz>$muZT8R}NyS8-~(Ub>xz9S0O_Oe(j*`pq>@S5bjcp_P4 z27`J0@`(nC4fSZ0VC+EiAKw@TbxM`F_P)GJOnVh~B zR=(p{mQ$A9wxE zVq&jzwSv|w(aZa2`uxauOFT^VZ45p?VP2`;`@R!|4L_Pn$1fKA8SFDkj^hOX!+7&b=jYDRha>H%}$(8-eEnHYGyrwy~&m9oL_JUy3)#kjra0wZ9 zdE(a|QO`#~nkl<4)q)?#Ms>E_BVPS(yk6pY_9xSBLga(`bdCEE*2>-hgbIMbof zuuO~%n5w2Xrl2Z+K1&16O&os3m*6KJf+#|a+WP0gT~eJZzcXKACQ zPBHUCsEgso?LL=tGVueYj{W&~XVk+2iWPpVjZ&}f3{35m25)4WmpY|N9>1LBbQ=}B zgM`YEo#&Tun}j=Z?_Y*n6K(!J{BXX`5a2#0FumqVuseIW9#?q{+=UY8Tv6ZzYa|g& zt^nWom6dHJukp#=QJ%WTQbShHk>xt&$*MwqR=(f$ky_W|q#x7oCzKtlximO#R6MsU zs4GZye8wsC8`scuO~B#WmF}$#3Qb??Xaa24AMDmwnEzh+Z>O=*n6CbozW!2i)ggiN zb=W6(_e8fwPXMfi)QEtBOT*9tgXpDwskFG{rzub4FtK%=odq9!^E2zLMSR3#6Ucg4 zxCPqOn_Wv}j%)T4V{FE*U;U;+9k&8+o$p@_5UV9@vozEW1)&+dHUNIWb`z%Y6A_;dj%CUt3MI7RRS?z_h`~YvEU` z-E4Qe|$)52K32~15+cSiL3P1`t+!35@c3~0MB>Ym6y;6K$HRBv*-NM1_ zS0p@ANu$2TX&ikg(nY%c zyGE@)h+SMi%)Px<#6!afd>r{``!gAXoHw3nKleQPxFbbgPF{PW=>^(t!JK3_{6jn32!sU9h|UR`%%X_H1tycJS>?T=8iRA80tt2^p#O($%)G zm>iyii3t)tz6}fc^ysI(Opa&8)EkXsNh$CXW=5=ijodLEE^0I?KR(LF-o8Yw8fLKf zb_cB9i|^#%(`@K^TEb!m4}t*YbmP3ij*K$@T8PqL+GTgKfPA@=L9)MxvlL`B4-U(9os6 zwY_-5h_9|B-mL7ueT#q98xZ`OC6KPRbRq>YsN?Hev(Bj-y&QVrYSu5O^sA4**w&m{ zJk^1!0)v7+ z2*M!@{wO%XY4M*^pQf1l8NeC|4O7xb*P=Dg7NjXLiY~3c?Il{xRPP`4%w#RgeNaa2 z12RXky1zVe%YS?1kI<)-fHonZ_l~l%4GetEuSDs9Q}-3LKXcECm@7hFL)0MLL%x0ZO8J7>Rg%N)nx@%KIZDrNy#Q86gheWj(=cVH-fZqg? zO%xpf2#5)%Z+tz#;5>|~jO18)q*hO0k`@nG}~=2KGwU*-cXH44ly zoyqX{$SsVRN{px>trec9`{CyB^|zagiDN%^sDfRTD&1YcPN$)#iT)z@ z)X?TF8CCjglRx34T(#8BlMd<4$k#6Sn3iY#)x~U}EdN0Vx9RPlc4rI-%b}Z_V$orr z|MCo22Fa;{;ot!t{S7fO?}BJ?`PBO6`Ju@v4%-Je?>{E;ynI1AQMSI-hBx94o_A3f z&e=aHP7JTZ#KSu4Ui?=9rV0i+Y+;E`O2;5n4b*@FX&99+AdZZ_CIN?WBmQ6Ex!=!{ zjk(*fu$-MB?_3m+2=cO%&H8a{{rdY?-;@6Ybuq9K{)3!^;R&QcTG~(Lb&0yUM?!VX zjG$AW9nxr|l7)7VMsKPZ9zJbZ*{8B8HQyu%+b8qeT8ge{$p5XY^k2LUTt9~YMTghN zL+e`s9lm50Eqwu~C6YiBC8MUEewG$We22L}F7xVDD6i;cnQM^5RC69&c;4IFeed6M z_6UKzGi#?fB@cO3=$Md%9*v~M}pc;iX(da3ED1UWvx1#S>k#FRZ z-R9Llt9e->3BU8zB7dM=MijG5eliwt)*~hY)2ps9A5yWzYK70CdC_IXXwuF#=+8}8 z>i;WkRB=V^*7anSfa0^iYnIwpNZ-SB7GQ9&7F&F&*G8f#y zcvSll!-{Ej$G|75V3={L{;~ok4O@9pJTiJ(Uxx)^N!byyzy~UgHgFT@S2v_E`@@`j z)-2=h{-yWdrY9dh4#%fhM%Dk;Eil-V(~bPcZC@o!kS{W@p6JG#o8|0IHZomsUfS+P zeI5szpptDU_-g`!bv;cPvF_XiNAFSJUf+3pg!HG^ON(&n3o!B_WH$XG=ezSeTrB^b z2w#&2QNbm<#Y1!U%#j5+@;`PH9v=Cqcb;fRxtkP{pnL%;zHeY434{!pa6XMk``$kq zND3SfAW~6Lf$Z&6e*RAp@7cP-r|18zRJe@=!sHZVR&5?y_@(?`I>96U^nb21o-Lrr zu#*GoD^iw{2t|tfkw`f?Rn}hYE!1L%ed=R?i4e=tQ%bKnyn_Axue-Yof#L6QYU<~pe9N*=2gm1-y$wTtXFT;J{4xtN!0+rj!+obqM|J~ES%d7JrUD} z-#Qqg9#UWv054HBnB$@E3iOBjKukX_`vB7{P}9-&>ChZV7GTCB*N>mJxbRxY*4`du zh_f5Ajz%>9#p`$-nNi;g%V~I2d=PCsQTPl@)cIHW`S^Yu?hCOnNNW;4hhhRPEH0V- z$o7Yy^$FI1n8hD=t&?HKn>5=-oV;hL!m2(wlvJ3^%|AUZ9uhj1~-;_3D4*Oy(E z4Ivt%I@P`7i~kAqD*N0#ZNtS4#Wno>4cM9Ab8e790P>Df@@H z_59kuy%a5DW~v$BRm5Fhl5bUpLGj|3o_fZk!b=b3BJV*}h zDZs*jo`z%5-rnJP^K4k={;~Oizf9ZpbEl8cbpD6_lV3~@nK9Y0y{14Uf3W>hgk5`h z-XXF8VYy~SdjIdh>H~%B)6xwrZ&^+La?CR%ZB~}ZtY^OcTkwQMW`t$>et=|0Of_Wb zHGf5d7a(nexH2XPno}hCITJ!fOi1XHn`@moai&NQf*28%j+C@`OXz=xdVBh+G5`i2 z^{_HNg{kg6>!_!vrA7U%vopDSX3jN!mpfY~r9;4trB>XtmGkxPAZ+{%+g9XKI-r;Y z9R!~|!KRjtJ_m9I0Y{!nY1&R&$~PY1If!JQ@Yx;#28)eB#8*Vp1N+?~C#LB1?irX_ zEh`UxJm9r)ny9`d{ZX*$27E4SeV_^U2ep>%(!bt^Jja4Jn{lE}>{RINg|5 ztozGN`35$tzzylSJ+L?Jk+&thvQM8rJ%XGq;>;U>77eUcWReWTlVta;DT;?5jW5>QT{_v|z> zpGWdCf7pYK0NggbBUm%jk)_TFfvZ1fmw-!gp%p-?>i|;&Lg5FoFTgO%%1sA>IYd`^ zgAh_5_zljW5`F=Ub=WHoLCGoIry*&pf&AWjFXJ0q&R=s=hHSyk;yK_&$zNWXr}+%g z6tD67i<+=r@-OCp$inikYJXnpJdc4uPnfSjgoAux3k2#w&tc}6ES`LP;&j`QY?BMD zCyAh1kDKoo)G@QbMxx2Kp~6`fj&Y@VC<@5aBwz!iD8$y**2tO?kuZSoZ^2y!hZfTc zd#UH%Rd{spF{hd)&w!4J8c4EqAigF6v)CjHs&VgZvurQZ|T*w_9^WwN^T4&DzU|U z9^d!7XJCaqGePZf9||JSt)Hl6-EhU?aRL`u6vc&DV!FDzNHw#+zhhl6avM}YAg7Z6 zB@TUfgCb~rV7F79eQ^y~ElXhHlLRm!Gy=@(?`X7&t%6{I8twvzcMMz<_zGjI8?X(a ztcK3KlnOkJ1^8@{0~kxgsd@v*MY^t2j6fpa+}S}EGN(8*%HYiZry@iEI~WCkL za%^{W@U!QH{vo6od>Wku-SnE@Ar}Yd>rNo`2j~vYjr`=YjnhrHUnfU(xxSWb@Y!s> zB9iqnXW0Wg1@hm(9I^*E0ll}zPa1jx$!u_H${(DBn`Q%2;Dm%mpAa;j94{DyPR3F% z74-!icHHX8e+F1*>p2#h(*4s*A>1~j+Xl9AJdc5MJ>|T)cn(R<;2R(ZL}a=T50GN9 zl}^DHtaygqYrOD|lGB1v3iQhvI%RyYF2{$bYIA!Vak#-{m~MGZ15`T5=^J5r$a(;Q zcZh-wbm{my)z2+oDU10i2rKsiCwY!CrU0`%vj8x7qLI0cXyHAZ{_NUR;GipXmXR80J_ zPN1wpxzdP9fgH=wS`XQk0bvrk!W3+LmQ*RHG7(tb=yhNzvhdMwigFJ@4>cK8MERtu z2&xz$OT1ZXH--#JvQ|)V8+@4iV65c~?q{{%-T>6|8LEO}nWH7(r#IsPl59H<^erfX zhw>669Ixkyy$W&K&aSS*U=l?!(KKcUt%w0Qk|Xm`F2`97KofK8Qf0z17(lEzrqdV5 z569rqm$0`z+!vgkIjKqOaYp^&VrISvgg2wugFlZVqJ=b7;&_95gTemp7E><>dLBdn zFt+m(EVA%1?%MemaGtZ%mID8qNZV=pj0+uXd?iqZ3LZ}w91;pt=&~$)U zwd&_cugxs%hwj7u2Nz4LO1)?#Yh+@&Npcx9N~b{DhGJup@h%($SQt2skZ223Z?dyz zHvzjwtTZG{f)sgDup8i-j?Gb&rpwo3jB`8a!!8;rCb4gujz*e$V zd5s6QMsS=^2DSPeFu%EH+|Hgo8wAXj--dUAX;1WU7U)(OZ>=1ZI9!|51KuKv z>GEe0&;>WS;f`rHfLUl-L~IlEYk-gdjKXT*&JOcn@L5yTOxDF#nb)XhW0$C7< zuXu1Tkj)G@H;4hRi=EaUis*yB5-ynF76CB~?*`8tY+zI{x6cfZk2C4=njoW0WD~ zfs>;*j0G_k*PIvWfgZ!$@5MJ56K$7?f9K7u>p*r`0=9?&hzl22S0w>~f!~k>%=Nxl zc*4a;z&QHtVO#I3m2(q36_(KB3BD@ED-seCwm`)I;j$@ zNOJ)((T7qaE`7;FIs;g-PX5eG=~~(q(X}-sw`s1&1nygf%`6fS`fUGTj~h9h`HLbl zyhmV!gz)%f_c?<;x{I9f=J>(8I_BaDLQUx8+JxFjj~=CXoRoG$mg(?n_T0hq%Eg|G zj(Pn=Hl^SyheI^|m#TZQtav)%?9e~rgeIKZu&B;sR?UvlxykqCb962c=`h~#@Sp5i z0-^y*-A1_|n_%g8oQ8&kl$3NA?3IzBM^dNo9#sW|{7w#3iFzLFIl*Lc1g_CnI0}#x zH&`oSV;){Ti$*2jUI)~<@jZx9he2Jfj^w&>jTW?iND!hKZ$9O~Xy^C8Q2Bs=&ADjs z0jx03UVV7E*XluqQrO{w;v>m~cB*(O@o51JRbB93$zXn(Bp%IMXJtdPFTHk$GZjVY zqbQEw(9@I`uSaw(we9fRXpu_EeV?8Uh#_HAm>Yx0Y#dBqrAMzt^ZOPp!A=vMa_5fh zkWJvLfm==O;22Qmm2$MMfWEmWb+{IYxtBzY;kCaECkC2!Wz)#v^3?!dvh?W)T{Wn) zpavqg(F-H@aVelW(GxteX}#*|;TlwnBiGctS(DGIj{tK@?l)x`eAx3q zuR+3pSkr!)gYE(`L|E0|1wu10I+cfFAb|1|umO)+02HERcuDmaZc2@1Ma0yoOPcz&w?Rcy-}#FeN#sHKS_=~VKq$d4 zgTN6KrmdfIo041TRRw1lK{u+}L`@l0JDEq4kBICis`p~^fAzD!OabZR@e8}#hx;-M zgNIM!G%2# zvX{Wj#A&|o6BK9qh7rmFIA9Q2AN+B!e8qjhh2j)7yxMS(Pe;-Zx)ZV=P_7;O&ywGLAfIEdEtgun^Cs1mTzTSt{9!Ogt zM#yy*h4JEGtIZ38fli*EPyi(QG$0a?0Gb-K9eAM10y+6GJRFj6@T|BugL5$Aqa&6U z@}fXr1cs0#T+c6p$V3u?AyEDUI2oj9eK3m=uw?*wFBkA(Y;~d7E;X{{-h<~B9sWS@ zM?*8qoj1rpnS8