From a3815d73c261560a34247109088a9c99cdc8e736 Mon Sep 17 00:00:00 2001 From: zhanghuiyao <1814619459@qq.com> Date: Thu, 13 Feb 2025 19:50:11 +0800 Subject: [PATCH 1/2] add mini r1 zero tutorial --- model-examples/README.md | 10 + .../reasoning-models/mini-r1-zero/README.md | 59 +++ .../mini-r1-zero/src/dataset.py | 92 ++++ .../reasoning-models/mini-r1-zero/src/grpo.py | 144 ++++++ .../mini-r1-zero/src/rewards.py | 62 +++ .../mini-r1-zero/train_r1_zero.py | 143 ++++++ .../mini-r1-zero/tutorials-docs/imgs/grpo.png | Bin 0 -> 73096 bytes .../tutorials-docs/imgs/grpo_2.png | Bin 0 -> 114617 bytes .../tutorials-docs/imgs/r1_zero_process.png | Bin 0 -> 20263 bytes .../train_a_mini_r1_zero_from_scratch.md | 458 ++++++++++++++++++ 10 files changed, 968 insertions(+) create mode 100644 model-examples/reasoning-models/mini-r1-zero/README.md create mode 100644 model-examples/reasoning-models/mini-r1-zero/src/dataset.py create mode 100644 model-examples/reasoning-models/mini-r1-zero/src/grpo.py create mode 100644 model-examples/reasoning-models/mini-r1-zero/src/rewards.py create mode 100644 model-examples/reasoning-models/mini-r1-zero/train_r1_zero.py create mode 100644 model-examples/reasoning-models/mini-r1-zero/tutorials-docs/imgs/grpo.png create mode 100644 model-examples/reasoning-models/mini-r1-zero/tutorials-docs/imgs/grpo_2.png create mode 100644 model-examples/reasoning-models/mini-r1-zero/tutorials-docs/imgs/r1_zero_process.png create mode 100644 model-examples/reasoning-models/mini-r1-zero/tutorials-docs/train_a_mini_r1_zero_from_scratch.md diff --git a/model-examples/README.md b/model-examples/README.md index 2a40b05..359b0de 100644 --- a/model-examples/README.md +++ b/model-examples/README.md @@ -29,3 +29,13 @@ ### AIGC/文生音频 - (TODO) MusicGen 介绍与推理实现 + + +### LLMs/大模型 + +- (TODO) DeepSeek V3 关键技术介绍 + + +### Reasoning/推理 + +- [Mini DeepSeek R1 Zero 从零到一实践](./reasoning-models/mini-r1-zero/README.md) diff --git a/model-examples/reasoning-models/mini-r1-zero/README.md b/model-examples/reasoning-models/mini-r1-zero/README.md new file mode 100644 index 0000000..c36f80e --- /dev/null +++ b/model-examples/reasoning-models/mini-r1-zero/README.md @@ -0,0 +1,59 @@ +# Mini R1-Zero with MindSpore + +*A minimal reproduction of [DeepSeek R1 Zero](https://github.com/deepseek-ai/DeepSeek-R1) with native MindSpore.* + + +## Tutorials + +See [train_a_mini_r1_zero_from_scratch.md](./tutorials-docs/train_a_mini_r1_zero_from_scratch.md) + + +## Features + +- [x] grpo rl method +- [x] format reward, accuracy reward(countdown) +- [x] countdown game +- [x] base model: Qwen2.5-1.5B-Instruct +- [x] trainable on Ascend* device + +- [ ] (TODO) large scale training +- [ ] (TODO) evaluation +- [ ] (TODO) visualization of results +- [ ] (TODO) ai-mo math task and reward + + +## Installation + +```shell +pip install git+https://github.com/zhanghuiyao/mindone.git@add_qwen2 +``` + + +## Run Training + +```shell +python train_r1_zero.py \ + --model-path Qwen/Qwen2.5-1.5B-Instruct \ + --dataset-path Jiayi-Pan/Countdown-Tasks-3to4 \ + --max-completion-length 256 \ + --bf16 \ + --is-distribute False +``` + + +## Acknowledge +* DeepSeek R1 [paper](https://arxiv.org/abs/2501.12948) +* DeepSeek Math [paper](https://arxiv.org/abs/2402.03300) +* We use Qwen2.5 series base model [Qwen2.5](https://github.com/QwenLM/Qwen2.5). + + +## Citation +``` +@misc{mini-r1-zero-ms, +author = {mindspore-lab teams}, +title = {Mini R1-Zero with MindSpore}, +howpublished = {https://github.com/mindspore-lab/tutorials/model-examples/reasoning-models/mini-r1-zero}, +note = {Accessed: 2025-02-12}, +year = {2025} +} +``` diff --git a/model-examples/reasoning-models/mini-r1-zero/src/dataset.py b/model-examples/reasoning-models/mini-r1-zero/src/dataset.py new file mode 100644 index 0000000..8218d3d --- /dev/null +++ b/model-examples/reasoning-models/mini-r1-zero/src/dataset.py @@ -0,0 +1,92 @@ +import numpy as np + +from transformers import PreTrainedTokenizer +from datasets import load_dataset +from mindone.transformers.mindspore_adapter.data import HF2MSDataset + +import mindspore + + +def create_countdown_dataset( + hf_dataset_path: str = "Jiayi-Pan/Countdown-Tasks-3to4", + tokenizer: PreTrainedTokenizer = None, + batch_size: int = 1, + num_epochs: int = 1, + rank: int = 0, + rank_size: int = 1, +): + + # generate r1 prompt with a prefix for the model to already start with the thinking process + + # TODO: remove + def generate_r1_prompt(numbers, target): + r1_prefix = [{ + "role": "system", + "content": "You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer." + }, + { + "role": "user", + "content": f"Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final equation and answer in tags, for example (1 + 2) / 3 = 1 ." + }, + { + "role": "assistant", + "content": "Let me solve this step by step.\n" + }] + return {"prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True), + "target": target} + + # Load dataset from Hugging Face Hub + dataset = load_dataset(hf_dataset_path, split="train") + # select a random subset of 50k samples + dataset = dataset.shuffle(seed=42).select(range(50000)) + + # convert our dataset to the r1 prompt + dataset = dataset.map(lambda x: generate_r1_prompt(x["nums"], x["target"])) + + # split the dataset into train and test + train_test_split = dataset.train_test_split(test_size=0.1) + train_dataset = train_test_split["train"] + test_dataset = train_test_split["test"] + + # convert hf-datasets to mindspore-dataset + + def ms_data_collator(inputs, batch_info): + first = inputs[0] + assert isinstance(first, dict) + prompts = [x["prompt"] for x in inputs] + nums = [x["nums"] for x in inputs] + targets = np.array([int(x["target"]) for x in inputs]) + # prompt_inputs = tokenizer(prompts, return_tensors="np", padding=True, padding_side="left", add_special_tokens=False) + prompt_inputs = tokenizer( + prompts, + return_tensors="np", + padding="max_length", + truncation=True, + max_length=256, + padding_side="right", + add_special_tokens=False + ) + batch = { + "prompts": prompts, + "nums": nums, + "targets": targets, + "prompt_ids": prompt_inputs.input_ids, + "attention_mask": prompt_inputs.attention_mask, + } + return batch + + ms_train_dataset = mindspore.dataset.GeneratorDataset( + HF2MSDataset(train_dataset), column_names="item", shard_id=rank, num_shards=rank_size + ) + ms_train_dataset = ms_train_dataset.batch(batch_size=batch_size, per_batch_map=ms_data_collator) + ms_train_dataset = ms_train_dataset.repeat(1) + ms_train_dataset = ms_train_dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True) + + ms_test_dataset = mindspore.dataset.GeneratorDataset( + HF2MSDataset(test_dataset), column_names="item", shard_id=rank, num_shards=rank_size + ) + ms_test_dataset = ms_test_dataset.batch(batch_size=1, per_batch_map=ms_data_collator) + ms_test_dataset = ms_test_dataset.repeat(1) + ms_test_dataset = ms_test_dataset.create_dict_iterator(num_epochs=1, output_numpy=True) + + return ms_train_dataset, ms_test_dataset diff --git a/model-examples/reasoning-models/mini-r1-zero/src/grpo.py b/model-examples/reasoning-models/mini-r1-zero/src/grpo.py new file mode 100644 index 0000000..abf3c9a --- /dev/null +++ b/model-examples/reasoning-models/mini-r1-zero/src/grpo.py @@ -0,0 +1,144 @@ +import numpy as np +from typing import Callable, Optional, List +from transformers import PreTrainedTokenizer, GenerationConfig + +import mindspore as ms +from mindspore import nn, ops, Tensor + + +class GRPO(nn.Cell): + def __init__( + self, + policy_model: Optional[nn.Cell], + reference_model: Optional[nn.Cell], + reward_funcs: List[Callable], + tokenizer: PreTrainedTokenizer, + args + ): + super(GRPO, self).__init__() + + self.policy_model = policy_model + self.reference_model = reference_model + self.reward_funcs = reward_funcs + self.tokenizer = tokenizer + + # Training arguments + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper + self.num_generations = args.num_generations # = G in the GRPO paper + self.beta = args.beta + + self.generation_config = GenerationConfig( + max_new_tokens=args.max_completion_length, + do_sample=True, + temperature=args.temperature, + num_return_sequences=args.num_generations, + pad_token_id=tokenizer.pad_token_id, + ) + + def set_policy_train(self): + + # 0. setting models' training mode + self.policy_model.set_train(True) + self.reference_model.set_train(False) + for rm in self.reward_funcs: + if isinstance(rm, nn.Cell): + rm.set_train(False) + + def get_completion_and_reward(self, batch): + prompts, nums, targets, prompt_ids, attention_mask = \ + batch["prompts"], batch["nums"], batch["targets"], batch["prompt_ids"], batch["attention_mask"] + prompts = np.array(prompts).tolist() + nums = [np.array(b).tolist() for b in batch["nums"]] + + # FIXME: unpad + assert prompt_ids.shape[0] == 1, "not support bs>1 when generate task" + prompt_ids = prompt_ids[:, :attention_mask.sum()] + attention_mask = attention_mask[:, :attention_mask.sum()] + + completion_ids = self.policy_model.generate( + input_ids=Tensor(prompt_ids, ms.int32), + attention_mask=Tensor(attention_mask, ms.bool_), + generation_config=self.generation_config, + max_new_tokens=self.max_completion_length, + use_cache=False, + ) + completion_ids = completion_ids.asnumpy() + prompt_completion_ids = np.concatenate([prompt_ids.repeat(self.num_generations, axis=0), completion_ids], axis=-1) + num_logits_to_keep = completion_ids.shape[1] + + # Mask everything after the first EOS token + is_eos = np.array(completion_ids == self.tokenizer.eos_token_id) + eos_idx = np.full((is_eos.shape[0],), is_eos.shape[1], dtype=np.int) + eos_idx[is_eos.any(axis=1)] = is_eos.astype(np.int).argmax(axis=1)[is_eos.any(axis=1)] + sequence_indices = np.arange(is_eos.shape[1])[None].repeat(is_eos.shape[0], axis=0) + completion_mask = (sequence_indices <= eos_idx[:, None]).astype(np.int) + + # get reward + # decode the generated completions + completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True) + prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] + + rewards_per_func = np.zeros((len(prompts), len(self.reward_funcs)), dtype=np.float32) + for i, reward_func in enumerate(self.reward_funcs): + rewards_per_func[:, i] = reward_func(completions=completions, nums=nums, targets=targets, prompt=prompts) # Shape (B*G,) + rewards = rewards_per_func.sum(axis=1) + + return prompt_completion_ids, num_logits_to_keep, completion_mask, rewards + + def compute_loss( + self, + prompt_completion_ids: Tensor, + num_logits_to_keep: int, + completion_mask: Tensor, + rewards: Tensor + ) -> Tensor: + + # 1. compute grpo reward and advantages + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(axis=1) + std_grouped_rewards = rewards.view(-1, self.num_generations).std(axis=1) + + # normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + # 2. compute kl divergence + per_token_logps = self.get_log_probabilities( + prompt_completion_ids, + self.policy_model(prompt_completion_ids)[0], + num_logits_to_keep, + ) + ref_per_token_logps = self.get_log_probabilities( + prompt_completion_ids, + self.reference_model(prompt_completion_ids)[0], + num_logits_to_keep, + ) + kl_divergence = ops.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + + # x - x.detach() allows for preserving gradients from x + per_token_loss = ops.exp(per_token_logps - ops.stop_gradient(per_token_logps)) * advantages.unsqueeze(1) + per_token_loss = -(per_token_loss - self.beta * kl_divergence) + loss = ((per_token_loss * completion_mask).sum(axis=1) / completion_mask.sum(axis=1)).mean() + + return loss + + @ms.jit + def get_log_probabilities(self, input_ids: Tensor, logits: Tensor, num_logits_to_keep: int) -> Tensor: + # logits: (B, L, V), prompt_completion_logits -> completion_logits + logits = logits[:, -(num_logits_to_keep+1):-1, :] + + # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak. + per_token_logps = () + + for i in range(logits.shape[0]): + logits_row, input_ids_row = logits[i], input_ids[i, -num_logits_to_keep:] + + log_probs = ops.log_softmax(logits_row, axis=-1) + token_log_prob = ops.gather_elements(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) + per_token_logps += (token_log_prob,) + + return ops.stack(per_token_logps) + + + + diff --git a/model-examples/reasoning-models/mini-r1-zero/src/rewards.py b/model-examples/reasoning-models/mini-r1-zero/src/rewards.py new file mode 100644 index 0000000..96db89a --- /dev/null +++ b/model-examples/reasoning-models/mini-r1-zero/src/rewards.py @@ -0,0 +1,62 @@ +# reference to +# https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py +# https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb +import re + +import numpy as np + + +def format_reward(completions: list[str], *args, **kwargs) -> np.ndarray: + """Reward function that checks if the completion has a specific format.""" + pattern = r"^.*?\s*.*?$" + # add synthetic as its already part of the prompt and prefilled for the assistant to more easily match the regex + matches = [re.match(pattern, "" + content, re.DOTALL | re.MULTILINE) for content in completions] + return np.array([1.0 if match else 0.0 for match in matches]) + + +def countdown_game_accuracy_reward(completions: list[str], nums: int, targets: int, *args, **kwargs) -> np.ndarray: + """ + For Countdown Game, evaluates completions based on: Mathematical correctness of the answer + + Args: + completions (list[str]): Generated outputs + targets (int): Expected answers + nums (int): Available numbers + + Returns: + list[float]: Reward scores + """ + rewards = [] + for completion, gt, numbers in zip(completions, targets, nums): + try: + # Check if the format is correct + match = re.search(r"(.*?)<\/answer>", completion) + if match is None: + rewards.append(0.0) + continue + # Extract the "answer" part from the completion + equation = match.group(1).strip() + # Extract all numbers from the equation + used_numbers = [int(n) for n in re.findall(r'\d+', equation)] + + # Check if all numbers are used exactly once + if sorted(used_numbers) != sorted(numbers): + rewards.append(0.0) + continue + # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace + allowed_pattern = r'^[\d+\-*/().\s]+$' + if not re.match(allowed_pattern, equation): + rewards.append(0.0) + continue + + # Evaluate the equation with restricted globals and locals + result = eval(equation, {"__builti'ns__": None}, {}) + # Check if the equation is correct and matches the ground truth + if abs(float(result) - float(gt)) < 1e-5: + rewards.append(1.0) + else: + rewards.append(0.0) + except Exception: + # If evaluation fails, reward is 0 + rewards.append(0.0) + return np.array(rewards) diff --git a/model-examples/reasoning-models/mini-r1-zero/train_r1_zero.py b/model-examples/reasoning-models/mini-r1-zero/train_r1_zero.py new file mode 100644 index 0000000..67f70f8 --- /dev/null +++ b/model-examples/reasoning-models/mini-r1-zero/train_r1_zero.py @@ -0,0 +1,143 @@ +import argparse +import ast + +import mindspore +import numpy as np +from typing import Dict +from transformers import AutoTokenizer +from mindone.transformers.mindspore_adapter import HF2MSDataset, TrainOneStepWrapper, auto_mixed_precision +from mindone.transformers.models.qwen2 import Qwen2ForCausalLM + +import mindspore as ms +from mindspore import nn + +from src.dataset import create_countdown_dataset +from src.grpo import GRPO +from src.rewards import format_reward, countdown_game_accuracy_reward + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-1.5B-Instruct", help="pretrained model") + parser.add_argument("--dataset-path", type=str, default="Jiayi-Pan/Countdown-Tasks-3to4", help="dataset path.") + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--learning-rate", type=float, default=5e-7) + parser.add_argument("--lr-scheduler-type", type=str, default="cosine") # FIXME + parser.add_argument("--max-prompt-length", type=int, default=256) + parser.add_argument("--max-completion-length", type=int, default=1024) + parser.add_argument("--num-generations", type=int, default=2) + parser.add_argument("--beta", type=float, default=0.001) + parser.add_argument("--temperature", type=float, default=0.9) + parser.add_argument( + "--zero-stage", type=int, default=0, choices=[0, 1, 2], help="stage of ZeRO optimizer parallelism" + ) + parser.add_argument( + "--fp16", action="store_true", default=False, help="whether or not to enable mix precision with float16" + ) + parser.add_argument( + "--bf16", action="store_true", default=False, help="whether or not to enable mix precision with bfloat16" + ) + parser.add_argument( + "--is-distribute", type=ast.literal_eval, default=False, help="whether or not to run distribute" + ) + parser.add_argument("--rank", type=int, default=0, help="id of card") + parser.add_argument("--rank_size", type=int, default=1, help="num of cards") + args = parser.parse_args() + print(f"{args=}") + + # 0. set mindspore context + ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O0"}, pynative_synchronize=True) + if args.is_distribute: + from mindspore.communication import get_group_size, get_rank, init + + init() + args.rank = get_rank() + args.rank_size = get_group_size() + ms.reset_auto_parallel_context() + ms.set_auto_parallel_context( + parallel_mode=ms.ParallelMode.DATA_PARALLEL, + gradients_mean=True, + device_num=get_group_size(), + ) + + # 1. create dataset + + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + tokenizer.pad_token = tokenizer.eos_token + + train_dataset, test_dataset = create_countdown_dataset( + args.dataset_path, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_epochs=1, + rank=args.rank, + rank_size=args.rank_size, + ) + + # 2. create train network and mix precision + policy_model = Qwen2ForCausalLM.from_pretrained( + args.model_path, + use_flash_attention_2=False, + mindspore_dtype=ms.bfloat16 if args.bf16 else (ms.float16 if args.fp16 else None), + return_dict=False, + ) + reference_model = Qwen2ForCausalLM.from_pretrained( + args.model_path, + use_flash_attention_2=False, + mindspore_dtype=ms.bfloat16 if args.bf16 else (ms.float16 if args.fp16 else None), + return_dict=False, + ) + + policy_model.gradient_checkpointing_enable() + + assert not (args.fp16 and args.bf16) + if args.fp16: + policy_model = auto_mixed_precision(policy_model, "O2", ms.float16) + reference_model = auto_mixed_precision(reference_model, "O2", ms.float16) + if args.bf16: + policy_model = auto_mixed_precision(policy_model, "O2", ms.bfloat16) + reference_model = auto_mixed_precision(reference_model, "O2", ms.float16) + + if args.zero_stage == 0: + optimizer = nn.AdamWeightDecay(policy_model.trainable_params(), learning_rate=args.learning_rate) + elif args.zero_stage == 1: + from mindone.transformers.mindspore_adapter import AdamWeightDecayZeRO1 + optimizer = AdamWeightDecayZeRO1(policy_model.trainable_params(), learning_rate=args.learning_rate) + elif args.zero_stage == 2: + from mindone.transformers.mindspore_adapter import AdamWeightDecayZeRO2 + optimizer = AdamWeightDecayZeRO2(policy_model.trainable_params(), learning_rate=args.learning_rate) + else: + raise ValueError + + grpo_model = GRPO(policy_model, reference_model, [format_reward, countdown_game_accuracy_reward], tokenizer, args) + + class TrainNet(nn.Cell): + def __init__(self, grpo_model: GRPO): + super(TrainNet, self).__init__(auto_prefix=False) + self.grpo_model = grpo_model + + def construct(self, *args, **kwargs): + loss = self.grpo_model.compute_loss(*args, **kwargs) + return loss + + train_model = TrainOneStepWrapper(TrainNet(grpo_model), optimizer) + + # 3. training + train_model.set_train() + for step, batch in enumerate(train_dataset): + batch = batch["item"] + prompt_completion_ids, num_logits_to_keep, completion_mask, rewards = \ + grpo_model.get_completion_and_reward(batch) + + loss, _, overflow = train_model( + mindspore.Tensor(prompt_completion_ids, mindspore.int32), + num_logits_to_keep, + mindspore.Tensor(completion_mask, mindspore.bool_), + mindspore.Tensor(rewards, mindspore.float32), + ) + + print(f"step: {step}, loss: {loss}") + + +if __name__ == "__main__": + main() diff --git a/model-examples/reasoning-models/mini-r1-zero/tutorials-docs/imgs/grpo.png b/model-examples/reasoning-models/mini-r1-zero/tutorials-docs/imgs/grpo.png new file mode 100644 index 0000000000000000000000000000000000000000..61088b7e6872dc91ebd993897872dc75b5fc89b2 GIT binary patch literal 73096 zcmdSAWn0v3)Gw?k-7s`XcOxMsB_-X`4bmW8A}}D`AzdQUpukA?&?Q~cAl)FfPp*64 z&;AV0am*W-ng5)**80Wztg0-FiAIX{ zpU6u~XnGm#Wg&V&wx+vV1NCj<(Gcel5F-d!mk1TA&AJy2Bc!TV50qUznGZPH4e{li zQRH9PIM<&J!FCEWs8Q)=F=w3vMJ`d(y{$cXYK4WjnlA!XAu?+J^U92PsO9*7f6@rB ztpC6FFbG^D@|phUr3^(A`rj`>7^Olp|M!iE#5_sm|NAsZcb_VC@&Els#Hz)&IF0{( znhTF6ocVvBCP?!C+n@2q_C8)kq!jV`m?7Y_I$oqG=y^cCu(;@Qb!56XTZ=qj=g`>` zO;&9+L|Ud>jSxx1*|Pw1<93+ii=q;zq>D^SL?))-@RWWxM~_E6;XLJ#MOt!hf41$b5OYOoZ!kvM!Hn(&({|cl^`u+I8^DOZ0*_ z(UONt%`(Jlf-fEBAc42X$bWyA1ke4n8yNZJzMF5cs{L@x46gZt=g7HlPskQjJOM zd2pM(zq_#%?LAm()~{Q;n6GzAy;I99H*C~?P8LdrxKJpc$mMfpgF*2Ma`*2%v9%gJ zf#x66b7shOJ-^N9>wAuS|IKBCQf>C@3^IQE@Xeu&f#tap4Mq%dLDl=Z!24_2`>#V8 z0t4I>qCODo;mr2QQk~=NG1^`F$7_(}`QJHDl0y)E)6&v%baaG+E1~wLn$Klb8kY^i z40#7(`ifH^w{i%fKeJKuJKvwzkzN2BU+r`L9#;mU)Ei6HJ6mhLvavzm5n}W`44W## zlY`~-Y)1pP`Ap2H#ZRy4!`YNt;jxGY8;9X9&F3u50nI+=cxrO-(`9K&Ew$t13j8^LQf zfLKG&&+ybT)wo9WyW*rGc+f(bHNMJu(KZ{NS*eAL;>%L^__M}@2K z{*WM}*WlSYM#e_gK6gDa6s(luXu>U2Z`)-nsha&TEQ_3rX9~89ZPo+$>Q%tutp4wp zHZYI!7J7SSo49tTD|^p&CS; zpLWyE|ITdFQF7~3@_s;kG4=~Rr_QGMZO5fguG0o&^Lj6d7;z338?6`1K{z|1%zd0w z_s485CnqQGG%MmYcsy;lua(d5K^!5d@OY`@2w5jNr%}^4?C}UGQ%c(9jI%xNoLHD2 z#OLhF6(NWw)-3iq@$Tj_HJ@uR(gLNTMSws?EuZ5t+Id%+WaBGN)N&zGg{Mz<;RtAy z9h;_U;piogVK@}@(tRg-u9iu(slXGGnEX1uR=@&|gSyBUxq9S3lsRD936=L^EDd%= zkUV7tkIW={+<$5LD`&;_HzOpXHW(VboiFs<ZEAL(ra?o?X{Dn1?gI2*x8m?l)R<_e`5$KGd%VP++?qd<8e zNG_+5AJ6jpM=x8&($I;PA%myl>m|W&0x^S2Z-jQ4b@nhA*=PQKTKO_qUqBR_NO8NT z`e#gWNJJG)q|U}<9*MmH>;)d|zeF<8CT7AW*m1<=LwHrCS8Jt&N^SmZw?v~jLMMn4 z{`9#aMKVt9#s$9e8D}B^|HSN?z?w3*{`0g_O)0$De%k@3$)q?-i9O3sX z{A(g7ew3nt_|dOoySGXr%+HeOrv`GHXmv5dz1=a!3d1rmC-l@NGUd3%6YXm%5Vaqd z8tHLgo7@~1e}z6n#*)Lrj*Lx%*@LLXcz=K6RPx&s#CgwGH?8iMYF$c(X!}URkTI9z zDUHn5W8I-)kK-xcY6ioel?;pxXZZe)b9JV86hiD*0k?uFoW>k54-9K7o3G3V4Vfz$ zF-LK%;HM5Bb2PoADK#S0A{{&6e0MrEj@@St36qo#=D1hVy-*MD7|^Lyty9*yRbN%f zJZZ*O5hrko;p3Eo!@U1dOWljBR!+XK2`O@Gmz08do{grj#$*fo9O()r=D29jVmP=^ zgA_y0_1_s{8niPfc>n<^JS-af!=C?St+mR|^cyvzS?gU+gEN_gvgA>x1-z@J!1FIh zUWsZ~-pk&&@Fqub?AjMi&Zo>Ai{`pptk>Qvn~Bpsf{lhPmv~uWgEIYllg<-|73T>| z%1F%zXM@#VXs3EZ-_zNYyN+z3_O4NnxBKF8rri}>u~7O4WDH`QKMn;S&4Zq6x0)A; zSa(`(nbHkvK4q{FzoN4GKI4}x0;trX) zRz=}HIQ+fQNX+L-^Oyg9OrCl1>iUMTn(`pGpZiqAqWaKD)jY{Ay;>$YVa@T$NkY#? zXBf8Z>E@8kj7vh8Q6tDgibZc97su-{{o--Vc)lO{@)A!{W~^s{f4Bhz8sdY4gG0W` zAtN^9dc~8Ub(F9d*ZPlxb7>lQ>Zw_wN0j4#AvFGVB)tFS-xsnKO>)Vv*rBmG7gO0V zbZA!NJ<88}n@?5_ukszXvaebnZp1$viesckPuNo&4sQBy38gr=$gHv@V^I5d`N$_R zp+pvP(YYNgz*bAvnnSXMH|LHiJAI=cmyhak15pVE9rR$X^uIh0Ie54IX3P+UH=C+o z^qp>vlxGNv4on?%O+(}6J%024mW+8*nK5v`M)PZ1rtuC<&~=JP==2u)aCaD(#^YJS zy=j^2ux+|#H|sOkbkNlHz5M2A$2!OK@EE;j85iH$?0o7D?|a{|#Aivjba#rt`j$6s zg61TtzW)&KF2BsbOGSQ1gvFVM0xQ3lbexFkoKnnhLgvc30AkT*J0EP3Abx(QPrno9Ax?v0)5SPCJM_Nu-;@DL_?L2B!_U$9jH zZ>XK7Ts=x3purn5FQ{4zAmR8meJ7b~(E8$%tY1}g0vrr&ugVo^yk@sD#{pm5^1z4t znFbfy3ImwgB8XG;3I!l%CB5c?B!nV^x|5N8@E-FD!?F;*`>+M-Xj=H$f?7*m8<>t!0%PnKq0+V>WnvtqZUIN$t46ozWP|XIz zA2zt4`eejc8XTbDHhb=Iut2gC$pjEa2+`v{%Z1~)&AqF4jQ%@WT3|czHI2*klV-Z_ z)lr`Nm2X&ch{4%56v|5AjA|!TJMab2xj`{)wsrXx=h)%>(!>aQ( zm-YS(Dk{G#>MuX+!~QAhn{DP#Zj0NqMWmK|$)PfcMeM zfsdElW6$*K9g|fR=(#TNTywg@gZ5fCC91~>q@TQWyiWDn92}K~Dy!QU$t9CIygF~o z*>4QjVbzP!)U98`aTGyDip>tO8OvWA-BHBL7V`Kf;&M&pIbu1F4+xyF>o~@VX#vI>D`#c(Iy8*i@{8#lXEZfHS0m5k(HEZ2Mu?4 zX?)hzx5_JU%WCmENaU0wKK81-)W?`FJo@+i5duEgr|X;{+C2e!QK79Gaz0lJKJ4O$ zN*4^&_2#`;LSfwz1c@{xA^s?^D15x>3PW__fl(8V;x_E! zVf%%d@(ph&mM@wRp)l+mFJYRHSKI#fOpyb%yvb`97Q(?}&+Rg9;e2_Kh`=O}jle2! z*q-u?vfh(!5Gu)N5)hKEUGS~gAuCst z)2OQ7{Vcr=>!jlDINmvg!*)He%L@grgR#ZpT0;F^tE6&dL7Gv}IINARdzsPl>W90# zWg@=*{-*W}lZ-#9XZ;@~fDP%76{}iQt@Vi6#op|RNe42d4%Ngf9*P&5b%Z7qLFzNe zB%=E#L~?jqcPj8|=3~lph|6?NO~MA`A+DWggMK6mq0}UPbyeupcZYMmfr|G>!l$?n zR)>EWX|1lj+77dk>IC2AI5L;fA4dD7zUt;%czc(8i&T@Bao@csEL)=dUo`iJK?ttH z0`lejR}B^WrL71tDcJ(EpL2q~cX^AU~mqK9DsUUB;IS>}{ zIfMUFP+RwJUyz4xeo5vMQv0*HX+`lJlOA2@ds{UF&zDx+MpUcxnuqi9d`$2OeB=TAc>0Sn15xpu9Jnui)2=SQv7fribF-R8aZIz9l@?euP11E{dG|;hdYw(aLVc1 zUdDhYxllAldTrl{irD8Mx@G!as9~7PdKq>Ar`o%edi!ZaeQZi}9P6{atdsLZRYgjC zQ3IE??$qzeW!my9-OhtlJlCz1ml=F|Xv|+b(7sHao-m-x!;k%yj#Fr4(EJ|mE@?z& zo!FQJ0j4Zmjn_^RX#d{CsFIyz_Npmos!R{xyeF!?+Pt?+3V^mHA#f?W<&Te#+r=Ml zgT?Q!q(q!Dl{Pz(6kS;_==~%QO3pO(n;gl&Rx_@?@7@%BO3AZXR=!X}q&Vf&+ap?0 z6LmMgS-wDr;0Q`mnwIUcIiJyIiUNB{cg}$j<#E583%BDzOo5)-cSV7;hs85KZO;by z_Y;o6N7{B+FW1wy@646w*I|lW?@#PY17-i19BZ5Eg#9 zvyKY#xNW{A8HgxjxVhL*H2NHYN4kAp6lY#!{6?&56w8h8agPWd16P7ne2?qw{(OC^ zn66=?M>n7dUM@2t=v4=JHiTrp+a7|r%GYT3#N1H^RUo5lc7@|gu~a7jE+$epgU?ny z*Cix0zzigoI9i2RcMuHdkrmWB&`Jw=&^}Q&?lr>9y3rO|{O4c21z{r(z>*}nevmef zE|&uf-cNqFnJOaz>|YT!8ZL%nN-g!j!D~BBxJ;EAoPF5?W4Z8-!qE4k7tn5Wg$Z{n zJeK_-4K8a)n#LO{+Rv4;MZ9x+M%Y+KXc*-a=;HSQsYJU3i4Y5DfzUT`At#n|z4F)# zXVLne>$WvK^V2RB+Q1rtNB!R*y0U~lt8B(-`h3mBklc7hL`BOxA;=Pk)_@CC{b8fF zKA4VA!ebFyv9D-@G(^m}`-Nz$SED z-|*K5+w@y5y(*KZ{VD9M(QVgphAn>H5KbclhBqKQka8)j?LE~eWz(;F>+^x<%>QJ) zFJYjJDPolg6d?+9ZHtQ(fa;SfQb?w=kMeeURdC~&i<7Hb^)uh{{c2;*_ma5C*XZ#0 zfp;YHmoQ`ls!h7;jVF;RXMZQrbt;X1!hD7?yXl*u0DSd+eEKZ2gE;^DHMCM{giDH3 zD5yyBdaXh~(HwsjO<4uh>1b!(q^wfiiJ6l?uo%k^-p z%bYIve~GCBnH!%l%jAkqYA%zY<&f#mSLLkth8vKh`|f-m`>_1y*lr@TYl(R+qP`bN zmlgR;3SR#W<+(3*^(tL!*X)FX8zT+XSqK;QF5s8PGo0XY)HNba`^ZaEOg+liOPva- z&1QC<a40HFAfF|@rZpCpo1m3S~qSO@J3Q#>|QSPdgB8Rxo6MnaZG6K>cN*l?M$*#F}@2}*+s(N-$mZH1AXe5SnrV&9YDiW-}zRQ74l>TM?&inAR8skRpz zJ>&f0a?+?mq=@51H?PwpCZ#RzzHHTh+}I@c8R^%mpx!+tWMXTbhswILWdtEJ))lTa zP-9_dwS*9Jo9Ejtw+Dq}y8E79lEY4DjPq#X)#13>xegv{P!o9& zSDRj-Wr_OMtxti{z6llJBQx1F){X$DOhe0Il;-CO>szTDR zK}bXIXuB^=BZk}@b$48*1C#>gH8@lZnaUa(cTzQ%f`OR z5%VvLq!k4li{C;th~r22W^5@B0J#t}A@rrCr8Hjan2gcG@r9Lr9m0EjDY))N)l(XFtu>JW%c(#f_92pvG1(DZrs9qZ@i zKb4Hhw;SK)ox&g{a$BBB^K}%wvrhbJPsHUq>IHD!R|a>dlkOj1En{&9J&H#T-T@WIfh6DI(-7Hx9gUZ+qFw=V-N0+8Va z?z>Z0e@iPZ{30OCg+hGnmise4?j1h3{UOphi2lD56*7Tafp7%TyXCpSCvsC z4?}~Qm>`@WQ~BY%Xgk>$2m`S5v(Auqng9Lumby#71>m?(Oe`u!a>QHhd+*XEqb|CM zjGLxz!fAMXFZQ0AU%l>l?ut*?XJX0g=0BV%R3ub!;y6?9lym04`YbIYBg6gYDqR(f zPt=wA*fGRFW6rua5 zU*Bk~!Pahjo6JBlZV9(QsqD2Guc5$Wyi-Q|&Lw^NhdC$EZ5{e&(=FGd^Q&C?EB1Hp z+fc$lcgc1d`ndip|20aTW0W@@)Yd7-o^>KZkBr8j)2D8)OSQ{c62pMFBOPkMMgHo8 z9LtaN{F2u}2v2v@!c19DTsQg&xSshP{gHTtaZh*A2v{Ihi=PnD5gbxM!TE8Y>B4m# zj$6e63eltUbIh5kkmU1*VGpXygT-1b4RSaL;`DwAN=?jr1Y4mpwm`l+?FmpmRDdg8 zTLWIz*EIYm4W1dyW}n~UMu@R)&$Aa)jK}8!sM2sX-A5*IHgG7IZ-+NLZjFc1E~l&8Go^o%+smkF!V#Z^8blFu zBRt0!wPPq&$$3M*)($tg+(

T5nN<&!V+~6+p801haWP-yKb+i%@u{{V!&%X%6J%Bx8*4A^w}MpqA?JYLYSsN{Z-Y>|pcR8q1&_p5GJv&o!#W zMYe#7-2F|hfNw_xjoY;2c@(AS+^?&bDV6h~P8ihn=WxU@bRjf4pWg42_flx7swUO_ zbZ*mE-o`M?90)_64!dL16-~vA!pHfjG%l&0Gjc%#*FW3xH)2XYd=8Tm0f#dMt!37@ z$pbxNGpYa+ZmSD+U#@6iPzcLR5i#p&lSXUE5neS5Vr8EZ9;dMyOGv@*IQDUAB;o9f ztDNfw9+cHa2?9Lse>MA9^O3Wvh_$}(*&>{L2@YPHC{hdqgsx&ci>Ie&pWbFLuhbR^ z;_>;kao8-Gib|gQ>T3``2)!!2u?vvDBodWgCcfeVodNp9n z!^?p{J?a70CK1}XJ5`RJq|=Mc>Vx(vcm-0|+vBi&=A0jb>7%wou^~n121)w%?6l&qRvV9$?x?;M0mTJRM8Sm4f$5)c)cb3v&`5=Ne4pM&c8i9-zwPFL zckl8@GEsYfz4l=v)dh3qqTXwP*hsI^1>NsnhWy9-Wz1x|-`$5j{JkPhZM(y~sOc8V zLXN3-(+f2HUMl~nH3$a{>m-a?{U*P-8ecQL!48$)^Y1Spj+ia zN*>q(_hl3)l1v0+2cdKN6S@5qE)3)cyc56uEj2Ge=Vr9y!cy}FeX z_y&;s&m!c1aaqB3$dK@vnZq zpG^3S44@GexR1eF@N^;9CR5ipQwH~AvQ+Z)%=o2oyl(3a6jWID_2$FmD^gHcrWKhh z3B)hA3UjOwa7RP0-l8H88m_xD>r`a^$W+azEYt3LU*L55$9rsxCBdpR#HfIqule#AL)J@) zqAPY21v7b4$mF6vnkm)(N+8U4p*x(TL5Z>%H#3mV1z2%A0i8zYYO&DZLi(=~5Q;48 zpW0d0i4{8Qp4wBHa110dCx-0Km=U=8JwbUde38PtGgct`LXFTM4E22Bgn zX`&E@CJvV$$SR#G1cor&`EI(oAqniL919qJWg;H@sMdlEmqQQp(8&Ia9*>2ksT=qz z)25x!IIH|APs|n^wkwxW%t87S{h>L!TY-4T#CN2y+A^FSKwtdI1opG>>J9W}t2!>@ z6za6I(OI5=TQ6AP=)8WVk?3{b1ve%sA5L`-3f8=NuICTH4BLP;$h`|e7VFZevK;!N zC3ywxGaQs;0hdQFg(NJU+GSh@}8jF@ZYm_&oS~*SX9{aXNd`!$8yL%$) zQ~A5YP9#<&=SJl^->Nk4YV`&$jb~cyu4bl|7evyhmx^!rfZK&}lo%}*m@sI)PuB4i zas-@=cBuLTvd%WSi_=@(gxtOgrAV^iQu1KWx`m-=nF0m84P-cY_XZZJJWyjv@=%s_LHew-pa;+4zaMm z5S|rF>~Gln0=?;DVqQ;LE4Jk#Lgh7Y`*BNh_4RMm*WtWGGuEyV*{D7Rx-Ka&m!@f+ zvMDsTo`CBOBHLf~k$QVwb9QmKglt}BL$$v6!k!0R@w7pu&ab}Dpo+%a#+bz97H~&q zjn>MWcdK5AXe_b&8Dmq4k!yX|9Q$M#frt3GPq@tM+Z2CNvtgLg0!x3krz?$ft7E*L zcp00-&?*X?`mi2iv)$ulk6?D*z4IUmEej=;)BnV&k=m7N`@6#&HR?2NNHAyLC%SjX zEw2)Z1Oibi!;z0ATJg^`jcRZ2cmV`;47wwdSMt=>sO*3hbY>I zANZ!eCqqW3$fY+H+nwh?(eYg9a~@N~>0?nr17IUman5Ihc=&X0mUov6q;MwH8WHP| z^x(V@tYKnar}$75;(+`n<0_Q`(I$xJfmoZ&VJVG>a!Kf{*S#)D($<7gm5F9w0{t@r zX735>th?9dy)o>`^h%$;zw2wF)uyzn_#x_?6lbob9S|$@sw_bl)c@7?x)kb&R`h13 zF$|^epnU0IKyGYE&3}cYDe8Vj@m&1V8{dB9&S8x6;df(%=V=I-!+AVl)7|l1Oy6`7V z6hE1W(M8{`a>Ws7aIhl@l4F*WHdI(*?}v zKU6^kD*M*^63a`IE>UXDihgp=%%A+mA%##1MG7P~Re0Gd(5h`l$R)|Jm!*sf2D(h7 zq@1SLzg1*2fGMF3gdj#B85z${A(#eTM#?8L$S2w<@2SOASt8D(&5+5Wz`_~8$itDd z1?G`<71lrAVGF4ZezpR3C!p!Zlc&U51w2z4Kn|n05Hy?QD1TdW^c0Gn>eQc#hPLTE z=_E>QT*wkMe(q!1zq&sY@EhM(grg*Mq7LR`HlFXm^sY8Cb2-KE&Du)m=p~Ig!Q7G8 z+7hCg7Cat=4=8h&mWznr*WBq8_BfDElmYdqMk&_78x`WM79lwzC|a9Ihl+XD_0onH zay*gRy**1-NG*7iQyuj3p^{K+&3paYg7-)wQTv7R7j>I?bAxZ9ZWtNa{0?6L8^geY zF0|?5xjV%^%0qwhdW5$V`f}MwEfb%>RUrv#`hkskI0Z98a{Kcd8y~t(2sD(nP9gCx zZtYa(WjiY_`IeK2c5Ur(lTph9ntv}xd>z)V!G%7jOchsBY_d&Oxw@Gqp(}9ijf=qK zh5<=Z^fCzA#u9M~jJco3qsWEu%HRF0QLJR7f8Q2HD7F{2DcX)ee%)Uu8Bc2XXXjE= zCNnsC>+&oa8Lsv9W&h+pix$%|(k2lvW4e)-bq6rrMQ{o6X1*xVtKGD7aPwl{eu}Ng zREy_QU9O zImz^%bIIp^&efN=+4ZiMQ)b*gdBykq>v_U+RP^}K8+_7uN|HaFyy)KrYh^>cwcA5g z`R%4KK=E3~p-n^L|L<%}A?1IZ#rbq>jvZ(@l$4ZHTK&FQ`r%cvCN?uw=T9C({>w&@ zKkoA_9UBYhvL~nc)yCX?M7~%#zLilqoJM3+%zK;oUd#Td)9BMxcW~w~FDCzkmNs#_ z%gB#n&R2#1jV6oFX=<46|-)fGCURw~xWd*a|b8qWZ%oAij5jdO_{ihaD4EAN4=g zp)q(c(EF4SEXVB}&37>-FEfNyTc1;Jf2nf5%XutL2LHoA2lj4i{eSo1MH z5u-H`F;%3OKQFHkZ?AQ%lUR%qi&!;fU<#9JdGFZy>?Ux5R@{E0yngyb-4E-^m%K$W z8;{`o)TikzA`c?BCC_a0S#2G58dt<>1+e3UHrS^`iEAR286W z+Ba6ZsqgtrDorac3wU~1-K#C z-rmo!8P_(7oIqUUK8=9bWDZ|0|8}&(E3ucWJr8KSgcXp(RQGvLTZk)}2UU0d=fr&7 z9oP6b@ajQ-O%zP?6&!8+4@l-~#{%5e!{mKB+BD*WvPE)RA*uOH@l=loKiJd_Lqk0& z*W|zgNXKQ7QeA4mH1t&3G>O3<={9Li$Kc2}Vpk=|o0Wg%;r{U92oP~cd_6O0^i)za z+-C_^5h6K_m49w<%rM59J}f^x@Erz*rW`(dSQfpzeL$Lxt)|!?7k$2$lb5xS9k@&1 z5>qYq=jY*dY_omYNH(uPL`sB+q~mC=gl?5dkbCt>j9#66x&c|YknYd-1VGZaHY{UO zpaZM~UXin#TcR^n3bjsB?o5^GGvL9tHzcQtA#^RSTp*&pG{49M!u`_<5d>}fC0a=i zr=Sj=Hz5h85IuBlhc^Qki=#iZeA(u!G(zUY2Lwz;;h%|Ld7#yiAGH+!anEiGx|Zj zUswRG$bSEATRm|eRlV<~S*k?^J0akbo~3x3RZq@s*4?q)=Qdg`EADr-BAezI4x_d9 zSF2Nhi5pK+dmH|NglA>GOa9M_8+w*6wscRKxLM!7a%kKxhq;2zH#{&)bUsyhvqLEY z{pw7uH8v2pYz)h!&2)32GWNi~xd2XhX7)s}3a!K358&oSUhR$y1HNk$08tXG^O>}* zhBKMFi6o_@QD9U&mJ+v>EpNu~ECIL08PB5hj!kzTPnABg=-zO|SgIzeWV8=Q6MfqL zO{=N7%-iI0UCer;r~8fPGQjFzJ-?qBVm;N40Qj@_Q7-mj%Emqa48wq8p@R9e;q8V0=3jq0 zr4)p2h16JxWQFJ9lB+{99GR9L_;hR^=Qvfazx6vb7_#9HM26L~Ca+@zfFWu{M!!6Y z(*`Ymv#&JOxIS#F0zGQ4M^W5p z0B@6=BtuwKc2p28gPbmXHze}wDEjozLz6YZ92%8-$rEql_fd-$d-c+@3J;;Sq&F5Q zmW>tuh3^>ZZ%P$84FVDN1YJqF&3kx^%y}wAo>!0z_+SvnV2fQnL00www2K>R`|WkA zK|GCY4Cc>)WLCzaqhxq55V}UIOr>_s=<(Kec9@m^W|cD*6CgK6CK z*a~W&K9wWu&hqg|hs5zlW~o-OziZSIvA9Mj>V%QqF9{Mdn4(yo#ftZQn#|}f-R#V# z#s1i7&9Qdt{@{TGK9y|HIHs|g>k(`PPu~)_O%SHK2OfZ6%#~gm`m+=E&BtaM|DSd5>CTXXZagjQmseqpq^D(p@WD@Eu$YQgl}j zP=p~wc~F1tCY_R3`-{Z}T(oyMvaR11g;!u?m!>AgVS!n6i>ZR?Z;x_D7+6Rgl3I(- zIw?kOv(>k=D!{Yna<)Jq6rf$srk;;6UU(ycR+`!HLH}aD4(x~j8Cs*kNdF;Q3QaNx zX2G|$03U41&lw|ktfuN-gzS3SF^7hzuc&nq2vq{YJJxl z2Ib6fRD1?$lC@=C;0q5UTru@|>WlNMp<$-Gp%OUkX`GgtbX)vjhL=if(sPOg;Spv; zkHr(hMU-MXjZg#;CzOcuhglIG-=23jQcqbwAt^cYI@5x#<;jGKcp^DW$#=2qndiYj zsbhG$Sq#na?gcO3`!VelFc#D!XMI$1zgA9vKj2tRI1emn6DAvm1^-6Ciowu1^2R~ z_+v!k@-pDjD@WL#rZiu8hE6Ymw8dxwV&BWCR|%b;G;bGBfj$i;lvg={_n$$oDDc0z*#2FT2uVh*Ouxx1Zj=hc@|fq0 zzFF%+Z}x}e0yzsCEC$;g4j#@`at|fMdyrV{Hd! zj#4ND#cWngDG&c7qhK<%Itk6)g?>3FH3jyPr@b2G%FhP%>XfAk>Mie$N=}ly@uv$d zeSuhneV%dq8HBbs4YH1)#9=aO5(F(!02Rm~HFKW94V%2iekm~MFu7up3z`@b6dj$M zNbZamxoi%S?0)Au2F)dw;QpawsJJGkvfC0PkdMs6jY8D!Y~4lFZ3|)l)L5dM zVfB9P*#lp+*>D|o!(~?L2`MCC;bF4(ey9_e$nWQyaK_+9o{%rrmEuWjKPk*P(UsxV z4P#fwYieKLl8}oBENrqc$ESmMtccP?piO)*RuJayl4Jpl40#o4WaPkZA_XL?^2n?R z05#11o$X9@g9Cv8#2>q>D@x5S0FFuWRw%@cuGw<*P3!En$NVhY&YsqUxFR*l}Ln>N3{;ezff$R7#aX4_Wk z^?qS#c<7{ksmi;dUkm?yKO)wP(A9b`nq1l;->AT~paa7N4S^5YXHUx13ZNnJV~GWV z*tdI#T5(o;Gu5CS0{S=ypu1Bm(}hAo&xHb3HgD&Gu!Bo3(+X$+eDpPO*-f#fZ%_z0 z8Wi#mZ-W%e+*NoCnx@9glS(ECm}_-9TCxT{ZaKCU(6$A4w)y{bCL39kTJsZg>)bKl zjc<0Qu;84-n08#KGeYYg9-Y?)%wVRd@6<0AoIzx6*4&<0Dk@-JeSO{kGoDL|#NMr$ z*fzvKQveONF(0#5s_Cs(sLTBOn~`CKk_~GNOMGxNnLs}12%G9=S?@rJ{Ki!_o}L$$ zl4C2Q)-mkMz7E_mpFs0m{;%VYTz*&1F2}MSYZ#zetlUETUkBpNmYz~hJ(lrTjv!3C z2X5&F?mtV^FJ)(Wb;!jVNKtuHWMbP7dxHlX*Z-aGsyhdTx_kni;->|l)#*X)GiVYR>f^QPM1F8ZsoBz z^VmJv(FewcghBifUyJ4;&i{)j@(etU9e=K6zs;=2P&voUajT&Rc z{>e#`lMOgoGCQU|yJAQP)qdk%iwPWN`pwS6SmL#}etL{AwBVutwU@;}X-*@`^;IBeptUgTb0d`C6C zj`!42Q^)Z`Kw?_7#BEV9Mryo}^;v)kCx5{>}S&q8*eUd#6NB(D;fsxivMU(c^ZX)%y zF{phU#Pij2>yGtqGg}8Fd%I~pNPjA}D$RB!VyCBsI)M2~)Si{SMi=uYm)r0~lVl7o z;tG2^)zbbl=pqVwbSJNDH50gk)BV`H38D{L05gyAu#x@zS_J1rQpeaGD+T8|C3$;% z^Qmf)pQ+!i%ZK&;c#@Z;gtGnXpO1ZmjHO4L)j~*abOr5jT4EmPYb@SI7Yke>sKl;Z z-6LLc$Rn`vOdVfN*G0a)I6~pbUIYb0+%y{(~y#WK&u0&rADlWBj$o$<76A9zoo0y zH;RxK#@uvdGW9M=l8)VDyq8{g5<1`r82O!%UkUh8UDXSq3`3*RqYe{_rH^?Cq|bzE z%;=Ru4Solgi_T#9taOAjX?~0IT1&g3Af)a^rHCNm4|((soaLJb*a7!RVYQd4^LR1n z@_L%>p+94gJ^8*=vxJ5+kbm|oH$tvtD4P1WP+W|ac9Kx?k@Wid`rnYM$3|)T+i!7h zX5Iq${q9HXk<=WdVWV~u$L@g=xs#nF!j?=@ z-ghPF9rVdyicoYbAVuZOuSJ(ByxgZxd<pDt*D(AygUOF!Q`di@-{}T8XH?0ed)(9IuJZR;f z@10H+9h&In9E81{@xAS^yFuw=>X<=y1P<`4qfYFx$(iYlkOgfi)vpSPSgiys-zAk% zdD(;>j(e#Fd{b8=C-bFP*Jq@^NFGIUw#3ER`pns$9dt*L;3J|zsvK#mK8qL&e7m8X zV4I-~+MO*p0<4dq``@H9mh~(9*_!6bLKQi|g1CZFvvVUO0pG_?m5x|6A;zAa(1lv_ zr;?5*+ZAnYp2xWx!|5_96pJsX#ixgGj4|rZ1gmhe4@&O|EFLJA&2@g4>!IoK-m$KG zEvPK~@f<_E_&#!G)T^eTC5JSad#VtJ}%cNOW$OkN{$+IR4yTj2&RCFrumu!!0u}(`pSr)i+&d)qNxp-e1Ao+}xyqEBF~p2ZKVA9OZiOUrr+O zv+ulxU*;!jtu29o1J4jRw43BcZ4P3yNCLOy1)~eoT!S$`XT4-Ru~B7GEvk?aCz^S7 zBBSM@^B@ppPGaGV|Vf4|Oueq0TR5-67l#|^nyE+$JVCtF_TjyhxF zw5W?0x2x4iNTEF5pGv})$8ela0!{V#uCNaHxuR93c*}`+`!_lXpr~`)NYt~8GmPj zq1T<0NaR;F@-o9eG;cRc-j?H7^VtPzo)kolqmE@--2w%*w>hJot)ySi)Ak)enbSRw zA~f7_2t>I*mj$x1`M{{pPrEeOh~w%`%HP|tbNcALqLWKQksbc0&-LX#|j0TAlF?KyE`MVa-Gh1A#x)pR@0|XFl ziL5$z)pr?qdEe=L&W(ZVNrD)e(luJG!DTy)_vjLXuCb~IP-H@1*Lj}4GfS2w#Iny|rujR0M+^Nn5Mi%3 zzIweFEEs5y7%*^3dtfkku-wz;GkC#I;JUJ^Bh;Udh zikSew_oHg*sH)aDz9w>U25m6Fvs6d!Ec;`7Z0zSD&K1eX4+>wbHth5ovy|gse*2lQ zz{KgLro*&wIS?qCb9F$*7mXz-9PU-G$0YQ-#DJf|Xq`PD2?uPU5edp}G&tJTRdFOp%t-BVT zbNke~!_5Eh{Q8$RJjNAJ4ILM)WUko1cIeAq^Tg2IV>JtFR7H7AF+u5Cb7iTFW36>r zYDKSl19`toO7s>@yPr}u>OBJ5JU%qG-yGYAX!Bq_)UdFhP?$avo5!FYw!+#Hh7-x{DAGx$Uh2$! zcPGOTmCl3P(JjX+3S*1PVjIJA)1)!{5vQ{ETnvMsACPN8V3`+&uFVPCksMA<9!X$% z&H}>#g)TSNMfuw95pV7st69pez6|Ck*;pd66r|=b>-0ZrNWfxW2(mV~oVVg+S&n(L zR&9ZT4oah`7<9?z*#WT3@Taf`lOb$sWjI7`f)eMS8I-iXJTGcZ!0jjhjz1H-?!tSL zDLg_rW{LOg?8K+w?ay=@ky3+U=`F>l!|WBLZe0)ZJ}-JwQC#E~efApZ4HdWZYacea z$=?(OCNl@oDtec_n!#DeULnrZCikpw8CR_E0_+#h9+dBbxb25*4||4YM#-GI$9of5 zQRInrR~au#dv+0l5SiggzZO-URZov?4@5B7=z^SMDpOqOM1->m)ln~Q>CHq$E}!*Ix%{uP?uR}%p55+!qXMm zC&eN3#&P{0<09Y2st$8Xc9XY+FbDfDWh69jjRFk|(8#_s5JuBd6q5-Eyj^xZgS}D} z`w284m^>DqYgd0X@WIhNnN53!>IaIi63^~#51_2F`k3mJ_`{z67n~oOd0_W> z|Jth(z0p&!XM$M~ZUm`P>9_Q#Px;LS(UXIT%!XuA8R%c3%4+xn`#s>0tMIQ~MSsPe zHry&Z!nWcI(hR-)!PlE|NyTqMC)@7L34Cje$gMJ$x_K767c8bTA=%J@ro$%e6|t&K zr$hW7K@Te*C_oh z?x^76%nQy{Rrxa_aDS$9OsAyLOz%ByH21=G$eWw?F3JyR8&4GFsP483s|DVTW_I~Z zBO_H?nkX&`)H3CLqu~oGoJI2mvvY>TXBNnd3!5aZFv3D zNd!>5J*vp`=@4QEBC{R*LF~6`j{I?rma`|7)PYx}%z8jYzd&%cBsmJv(H%REm4HaEj3yFDBhg?xaRRqg8E0(Fg_ zzKw`=U;C5|J3G_O&&)o}NL=Wh>Mz8+oc4v0x3F|VQr92Kz64o}7PuW>s8H7?@TppQ z=W;areK&zaZ0O324U3!mef#|>SOHNmli%|U5#8}D;tes+$96=K4j^tQM1#H0^@f(c zz{}`G$^er4cR+u9wGv7Sd!D&S@mBRo@#j7+d-cp0;@a?l?3i`>F>CExNJBrAHY4G70f1?!TRpet3! z2G>pW+yZj|f;~_}hgX?Vy)ZkPF6m$7)W#%%^e_rjGS6Y=2_HNdX!F#(HH)F^5{sP} zwOr1>rFm>(-_;NqSZLY`kfZH3EQ4IwRx5Fxu1EQ3b2kbhC%ONwNXkVk-|I{-nBSzi z_P8A7z_3=hzc5%z!Jq!^`Qm`k=!Nk6hTK$=lbbH&&G!3}MrcuQxCO>L*wszxl@aUAl8gjc^^ISItIpE-;P`8S7Ww39Dg=4n1wjPN9$ugHpoEY~*y z{smaY88#|r=|}_)a+}wIy&w?S?Im*Vf`JKFgSqvv zU!9#Cc_n!0=TW;Fwj*o_5QV(45xn zpa%H{WiXWOnxD%P9vnPE31`*9cG%;%d;ubjh(QNi+4=PWlr_OW%s!caiCcRu zxU%Okl*7Ny+NO=$Qpb%o6qS=!r8b@@vDCB%hwsi2EbqR?R}eY zZW zCnAiZm!q^T~0`s6a3n95t{9h>KKP#;c;_O#g9opj0-I@+wMsG2`D1NE^>KX=eS3ez7VDNc3q(qq0#%ZpHGUc zz8~b=VAauc83JO1S$@>g_m~17ChVQKRICA7{JdK2FjMdB+ajPb;$}QGnfS8GNr}h5g)7ka{ ziqqRt;ig}Q;Sz(t{Jj2(nd@ei&A7w5Wi?+5wq{~Zu5NDP>Yvdp0RAy*YOoa0tTrio zH8$?V*hTh_H@g*u&-V*uQokx^HH*+*w-kK2x!%luDGcSx6lc^2W}kf$%C=!?R?-6y zhmyu=N3t_nMw_!!`SF`KR;li_>TedRk2PeIdjb$$+PCZa0;yV8QJe9Gw%c#jVGMHw zOQz+Y<>t^d8U0tk9_6S4)9L3!382$55DC`q>^Rgs?-pxH$1zX-^7Iv?xf!SIu^Tk# zcc+Z;t|C&wG^%Lnm=3bJ9S5jaMfhi2{_q<*!Ny-8HQ=94eJW4z`4YGK5*zD*QM(?! zavq~gUc>uyfCEsVJnz?Ve+S5bRr+eTFj2)W`746{6#7KRnHcl5Mu#~eW7_d&Ky@bO z|95{RzlYozpJS@h4~PXtHVHNzpeO2d$WE6o;x`2Lsk}rd~konp@0}G zD2cqNmRLx}TJ-9*T;jwbD-CQ9?lj3dn20oOJ9|xJ9|(Hs3-(3aBeIWHGCWjkMa?Iq z3Fcl;*AgHOF1l(|XH8X4PQg5T1!%Ds@4P^F;G3b6X~Z=p$T6dj&n)U(zona*3oLekNPKxJ(D)d$!} zFha6_HZZ;_DsVqKVzymssCbeOf6Ql~%V?8Cr^P1!b@}|FuzIRFY(gN=GVDMW0H0#B z9b70{WhV=OkXU{?WHrY9uem|;VoDA>wlgqdp^Xw6qGhS|;;*i6bp6|!Xuy!h9g3V) zm01qi+?bj}<8uw{@els4QS|N!x+8&@Mi;5(h$yrM}_lI=bAj}neSqYKJoWj5@L3!sZxmJ^OSLiwD- zE#c9PBj&*haG0c!F*S`}kxQ9}FSEwAn%t{g9MYB>jG1Z8LM9TaC9_}QI35lHZ9O>r z5ZW{>BkBQEps7Cd<{OSuthV zc727}M0)vKozNxR{sB>qfLLQnjjrFD$m^V?u%*`+mnxIa0KR|jk-;LMnFdTz|6@mTJ?(nKZB6eln3FrNgdvWlHJf7=LDW`YE^~#LX{m+<)BS=&t zH*kthstSYLyF&93K+~@`vO=S{IVVkGSOmgiNCkBV{hJShU!SfSGas&9pc2O-;+g&x z$OY|lb$P_E?{QvkMZydqhl?ALL)cWKP*Z6LgR43Y%WE$pKiE$!NR}eiP^2cI=mfNZbE%&`|KO`TVTwrt~FFGjc+9!5>CsoS1nCaC1s8LTw zBvWGZ0;hYbc?yrqjwGX`1Ofl&tm*LYQ-)R4i{CUnrgIeGMreCn*+J1Aw`vl=Q&%IH&ySRzuFw)t#qyg9S zFhDo8s(VN>!$On44doOpg>hfe_-Gmrs=p*|yzDx|$EtE)ajvnSZhtk#WV@;{Za0GT zMAdfR`Ta53#Ms?d&%*ZHnL#*}Tckqff4gdkV=h7PDKUvoQ`uK*io1P}e!27rKk(WV z#8Y+xD-N`|YNeh`i-nuu|Ezo=g`mbpUTqy6Vpl9TPl($i)8FKqH&4`IU3g2k!p!Je zUzrcIs61=#bJh%No5_TGXZY!ipYY)H#&Owwx9hbSjF`VtGDu%rwO)jk-e;$i)O0l?m6!zWWUzy!OU8nVi<<@yLV(clfiFDm`#YG%9_+4a zeO0w!TC3*BZcTDfUveqMe6kX=JZs^zf>FFMc`m9Hc%E^DWs=B$oc1VBT&YpinY31jRXjH!p zODixtn2r84DU-*AfA*8PGY*mz|D2JLk(!YxUc`#9MKS*slzB(qm&c;Hlvuj=I zK0Ue`5#jT?fkPu{@^NyxFNJ?1qbPqMHjGK6{Mm&ObjA>4wc=jAsskNh1vkqR$9cqE z!%-RMwHM;aE(1|&uXtjgMhOqLnMJaa6SW&lLXNB^e*M5!YH+g1gtaThF;B@yWH+6C zNQo*|D%OV3rLn@D7M*w8ZQoZJJ}~k9t|fdu^zpxsiQ#)NQ_si7=X-+PAblVHll{}` ztBScfHF;H(Onk$G&XF3)w7l2wvDtRbL=lN?(Dwrh(R5r;D69UZ#s#^W!Z-4r%D-aWb&+Mr5c)teg%Rt(~Nw zKZwb^hD&KChG+$l}I&61$SPA`S5iTv6FL&}9ks8Ni$0lPpk>M-}lYXi8!xhaZs9)VcN9 z6=>a0UKMI!@wub3vT-ZDvMhh`;srWh@4{UgZBtKWnhZ+?`Ah9iY^KV;AoR2}6bCsK z5hb3Cx&asOvluN0nZ$^=N6KBJWLjTA|_Xs8+P6#ijkL1aISynRV5Q@sRl{)2u&!HUr&=sZNE|OXho~=fKNWI zZp!U`jJ4atuBt`DOn?UdVRl=~TX4#uk1am2Cl!w+j)A5BUEVp*oUDKa2eq~eKIKI| ze2NmqBMiFK0z%0djbAUveA%q0hdsHpKIo~7Cw@h16h7udjP8x^Twq6(S~6gb+vC z^2ZFmoWd{TnRTuc{JtA{^Ech)YUS20{^Znny1{{kgN~5mkuih) zL)d~+g-3q8(XalMmvitKjm}HH{&H57f1LJ3nb*AcC)DF^>*yqpHI>8lr4rnCS55Dd z*!eNv%fl_*qa2eXhs^^4MjX`1LiQBVkFDz{AoKa_4!N?)r?_ySv~+iO?{vopF5b9$ z;QjeTO8wD#_7@Ux)7Frd?o7D-WWNR(pBB&*s9pK`OEIC8LT)y zBa199m&KdpCi~@X=?2GjCJ;g#?(7coa5NwTBfVMrL=qs?#@2`Kc`WOhx9qcjeCf!L zqN`nMLo{JB$IF*5nGD;#n{46|If#G1JNmJya?;Ms8N|%MfPCwJZxE+dUIC9rNN0R- zXj?$>qfCbh`~Q7qgHx|P?N3ej7Pbqc1^;zIQ9@&~yT2_Zcb!=n5e3GcBj9No(uw1D z!j@XC!u!bONa{)n5cs%0>kv{Jjh!$nJmx|hmeQbKgvwa(M!!+KVBo!*o^H_$jG$7d zVf>*^6M%CXij0wTO6ccilNI1daSGMTvEEuLNUxfvkJ5eQ=)}*Z5T2f7Bo1cjI$ge8 z^6mVS^N&J^`yW&45Mfk^%U|g)$p-hI)^Z*>E*U(d+pJG-8YJ&*n60NbZ0@D=;k3g> z2EK!kiY-7%mc{{HUU-9=`=5C+AdjsDZ_B@66Y7$}q9g%(jZ|OtGSHJfRW8gh*td({ zav`TBa(;OFr(-GkZ!@rUENf~x*TW&c@FOz095r_&hG_9~D6!nAP`820taDnAmEk%& zy|cq1eT6D4kWWqpn2IoJ__UbNOh)GCfYa*&SvqB*I1!2~5xuYA;q<2#MIKcn*Vwm2elv?~kw*9xMkJg={jO|NNxPbZ+o;U&RWUkV%aNtY_A`R6 z&7oB|QAFKxS)14Uq6qP8Uc~kuv@Npt>E_YfJ=DIobgYZ_%p$XqQp}wE0gHMqyC;OR zw`a|Z>r||>>$m)bC+oh_FpZ08j8|&2$CrPG&~kl?(#~^k^28hbfK$VfXbfL0-s7Jk zYrPO=La{3_ONp+LF6jS${%`?O)h+0isf2@Npx{*3`ggm!oBMZVR0>`zk^t*YgR$G2 zzW7XK<+{2EM(_3aIkL;lJmv(DC69l)Z3xPtfEka3ZY zN~F^+B$u&S?wBU|Nm3nIEcMw}P8e2`PO2othIAie_x1c0Sd%~W4>B9o@sxjVfVu9^ zn($jknnN7110|J=WXtBSR)_;Qh)EWv{k=VYMrtx2@Gsai#W1F`DtZeMPEm)>Bzp3A zZ|1bDjdvr(H`WmLZ=kk+(WPes&n|vlgj2B2s^Z1>`+%B|WL*3-BkIN6Zwnk|&9?#O z%@S@>;7if|`Vqv1;xcGgzU`tJHcoFKC|mZf@w&9#U7&axdGUYE-3_@}hJvdn4 z<#ITAt!z>j#K7;>FqOY8hHX^F?K7)E{!q(RyFdTF&9Ky`RCJKUA)_~a8TJ(QCY!~j zl)v2dapis_>14Sra8e?SZNA9`!9%+4!z>X_rQa0h^2(UErm*KnsF);Nl>hEPRZXwv z*A&=a9=DZ<^Ag2+jIkm)U}dj_&W6qFFlGt<*2Y{k>rasPmMJEltM`6=vT=ifc9kx= z|KK~2If-D`^uY6Wh~)Fi{oHb4s%}8f_@aa`KYlm)%6_um*?ZPP>Omi%b7+Nvc|#WX8Z>m_04C#0D|PzlnHX7+B) zpt=8v7b&kus18ESuCA^$BC#Ww+G2X+-Agq#!@(^RAh)~06_BCP!;0gDbQDlABu(D0 zNK`;cA2;tcDWC>5L!)qvFtZ=Or1Kq6$IbOq$Lx?m?c!cKcZbo?Hb)gyMZr2nuxWB_COKJD&y(1f8xpLkYvQb)g`&iV7js0ClL0qTx z=za+@iv%<#8*OzuZ6@qDmZ|Z(u_W^h_1A^GV&U4QB9fv9PjOeC?|wV=j-Xw66im=y z4_1;np6zBZ8KuL~^X}ZU(I>d@sZ`+_$^SMTC`)%#6%Sop6Fy5+`R~)BJ6zo1j9V_E zkrvYX54AD|>!D12=o2I?_=uAnCYiHZa}+FuC~#MVCq175i;I-@AbhA@JrMH8o%~ORu<0~6W`(7PkDh59>Z>YS%eCmX;ukb2F2=@C|9mLV! zi$dK9J)Sk}UEp0@??_#se=wes>rNU&8g~`1sob7F*fq=3iypNj%w#%Ay2Q2HdRu93ys;hlv%=#Co>G!;Da+%HGE z7U_F@_zTOJUEf`be3MPr8e)Gh+hAFkISaMu!OC;Ca`gx5_*3E51G%HcM!S{ChUl1O z!KRdv7}Fe~=j!wH(lA5fQqdSAq8f*xwOE3gWqziig(!U1QZeK%CkRa#5j|8EY1^s3 zGp*XZuq{bCp2sgQ2_Vy7k(E}2M?Q@(xh7NsM( zaz8(OxQ9`|)p4cW+}Y!c*s!HSp)=%%h+j)zFEFp2dO=Tuj@X>YdbcX ztoQguZt7(i{wKV8wnTw6fWrLyl_Jr}Q-AC_3VC=nW zhYZp7?PgVJbxwE0qjeN9BpN}!3mNZr(YC)XvhIr{HV(w|!l9U5#PUJS5D&W!dzI78 zkk|Y3L>V&v`vyIR)nh!G?(Xl`SzqSe0r8vA`iMM?>av;H$eS}3 zy7$AksV_FwKZLg|)oK051Z3tnjW80Z;v-$`DIEvKByH+)b8`#T#y@8=lRJPOXc8I> z)wX{h@8#~&7}b^_a(OI7mocIcHV;i;BzlYT+Vyn*zZm|__k1%A%@r$bl+PFR^*La&2Vitq3|r> zJjFj46EP?XmgtKfo5bm&V7?5aiAeXqSYN2b3e3{Yho{=faUEtQe>RxN?z2u+WI-N# z>X$yAOX8GM>5+v&!)+ncOhxXHgqzeGuW!BB#64Yw8{ZoZOxD0Q^>aLo0F>_tfFtbz zsmNe|iPf)gKrcTA+$~2PyFUC9U_p>79|}b<-t+Tz!hOML*Ot*mcBD|!T4^w2{>xe! zU&Z3YHm6?Alv+(K&bDXq+!$>0AP1jWBJd!52}c zI*%%#w?A9cRjQH~CR?A``zpEyg27FKHXrFTXZlWIkS^i$E!hrTq#Ncxp3pKD7>{WG&$G89 zf&Kil%WVYe11Osic1ISH3<)VwdudkfppY#b}Kcebbc?_I7Jkr zzu-NutgJv-K&uuHJoiQ*X;!1&9s-1AK|q`P)rJC4_So($v|wuo7L9?CFKOF}{DS$8_CS*7!~ za?{1E)rsygBPS}5u114EI5C>|wT2%biv!%2d~m}tDH3is?m^X#VAp*9mc_XDT}DZ- zDEREdRYnq!VzAQd-Zl~g@46^RNE|MZK(mS$#xSVQ2Jgd%Lo~(^ojQR7lY8U3$&9Gc z>z|5p7gvDhe(+&Su>0ozBhm_{6$+8?qP8G11Z($oAd_$Iw$+^_dK{e2h4%3nmQ^q(g<0XrOs07hzPAn%DW!z^CwkdCGZNb2iT6^E~_(%q- zcoiz_h;Pe5@eDdV-q-7?RND1SApI+#%BcH`?Wye*jh6?7qY^cIe5M(CL+t2ViA>kW zeBgxLnkY#PJAF(7ULJTP+~XwRlPLiG7Zjvl;sDJd3NUEXJ;buzl5KAd=LMz>*IG@< zV~W>VPDmq4fwZKNFG4RN7a%uL3~XA_v5r4nuZtCml?vb=Zgv^tL0mP06u1LKx7@br z==yU71$99hj6UHj@6D%F813KAeK*n}guv02xVd(EMDUae1H51(pflSJtXuL0K=OPg zMQ_mqnxFBR^-=PJ_9-9 zDvn`;>GLQ~MQ82B9sT5pGVO534&zoi9j}+T1d;M?dc{bSda1FIyHxD+i~62ojRT2E zd3=Qij^ZO_R*ki`-5x9}56OxG^A56uF%+rqa=9mDp_dctdIo@+#8j;UOD+CKJS(oi zxFT~pUyZ*hqAzB7lgBB1u*vc#d7qaxxZNBQSM@UnVC`w?sXCHy`*6`ZgvruSj~a7@1xB;9D7%2 zeXKa|hHTK6nx^#paBk~{?BhfW?imG5l%crMiqLP0RM>g%#6`=TQ`0P78)h_pX_E&k zn?T^ue42`ukhB=o=neb+To$;!6g{IsCIUn(0^87#V5YK&Gz__N=qb$s#4%`r;$2cK z34=S3i$3F;&2qi;!SDd@qbYhd&S}W<`IK(p)JAa}Vz;=@21el!;XmSv94#^e_M3YC z<*SEIgjWw&NP=el=HqUn)n8*j75=;?HXevua`Zlwgy^F;ut$dU*X&NvMK`L~ttov@ zHDBG3KM3|v@shzGBbOwXE%1%m!zVV)j`hj$NvRCOY(V_){S*{HMd8TzIM&BbRq8Bh zD4zY5D>MYPWVkBFCR!?E8!?Bv4%zLK&Vqt*KZ-PLE7g?U!dZR%@8{3YMs7wJ2gkdW zMAW~0cz*mo(pr=~zF-L&_+wZ>InSOG%s#DIQkXXsW$?cZ?%>?ZM)GZ-T6S{3-O$ZqDjy_i5ohB@T<<;~U?>45gh#u2u{| z@5^Wy_5O*1&e{GVHWF61#46e%mdb$rzgTG)6pq3XIanA2IX|iwf3wBalRm}kHH&d2 z*|*G^nDIHX>lHP6ypXPH1pWTMRqozKJ=N!XFpl?watVp$rpbm}JL@Yf67vOwyj3Ru z+}XJKUfBcRP zM$=gWAYeX#3gYP2%;YA`$Y!Oe)T@0OlDy1ZQ@56N7C6U^P?o6{|C8h2k5$k^u#n{P z)Z@84&`@jD{wwJ*P(0j7AX)C7IZrb}r;Ay5t;J&T@x91H@PZyUw?3t26v5~nGE8_sWZb`4cEmf z7w0s1x~h3q5~YY=mkvqkxzrmq)AA=_4`vM*&S<6y_9^M%J=dH~4dThYVu=5#JD}L2 z`0)jJSgbWx#mba^mNY7FSEuw-68ExoBZma;uLKQ)3sRiixhb6*7nDePFiN$Uut6|X zMu){>dhe?mznjwT&CmUxfFO*|`TkouW`kr%V;IptE)e(-3#7deU}v9hbgQV5{@T7E z8j~RkA8gP$f`CS%5s8Cb2mKqSEsLq18EBeb!ef9z;fN7{$sH`5hxkefY%Bx^lM-?t zxmy=$a?eAXD`KsarI@Xbd%VTg8u9=Nq=H;w^eU(j>kjb}FWTIcEP8c@etaj_3jAId z{+f3Vyho-$OQ!ilKo}A3yY5rs8Gc+|PTB5kl*|Ccg|bt^W7GC4Y1sTTHw^AoAiqeM z8$645dOt4(dm6V#EtGbtkQav9DFwyF$roi1{{POrAK^qbdSb|T0r4*%-vW#1Bs($W zyGi?8R2FwjPw#V^;9eIFIbx}26}A%9@HmRUmMgyNG{Y%>yMepHU?Dxx-|aq4(y%O{^3w&Z9m9^zcQ0%(Y;fSGTigVCTWkCd(t>}9rj z0n&^+4(A(iZ}}BFt*PRj5mnq z4Btcs7JNCm?Y>=@`T#)UW(soWh@phl+4P##!{Kr<7f8TvYIK74hqaLVKv*B<+UhY5 zxKnP`{G=Mtxz~t@=f<(ku6d{2!qz;S+$+{nb6Yyd)^cq16rr1-fbe=d{!TudM>Str z9t2Hll`4vIC-pW+fd&0PD-0pNKXX@Np0DmXmV_(zPxkNXC?iq#4|A~^b?$Je?t`5& z-3|~(c#Fo();GC+H_S9Foaou(FE5cXyw{V}$=S{mpD}uab#SNI@+GLG$mxa!jWVTB z-1Y~Yj8yIg3iG!0LK6XSh0%jV*)n)_gdrwDOW)RsX=7txJyMDuUNH{`=UR{tSQY8wRxd zGdU-tZOW?;5krZ2HA=Y;&VinF_OdM@S0ox?z(WNw2)OFaVGo33S>cw(>5oUFtx zv%$TzF4u6=Pe2P}W;=j%1Pw4WvJ@S&Ucl$(3%C)uJkP$LR|&%UVkkscBQAz3$y4(mW zmXa}&&lrPEJfGc}$0=*?#QrH^FdtyUi#SW2?gx0Xcz??7C|bF#Ogn~pyun2`WLbk~ zHH;zP_M_SkUq_2QlXKvkl8@7;W?PuIyYn%wZQiAzc~rAFg>}b#+125y zZO6c2E#>{QlKg>@&H5P4cR&xMf%jy?hJYo!{b>XZm@@J-W0+YNOy&(#DiB`*vh7I; z4|rL{LBu5R4`M*(y1c({Q+{?EBGvrM&y(boJLl@1 zU3ux%`8V0e?8H;#-P6;1pCN(%)Bjru&$+#bV^*O5SUmT4<=ufiW`~m#l)bCz?IX|Q zp>~g^OfSn~f$T2`n7`BKbb=#s$uGpc5m4u5eiT5Aww*Ii(cOGatQk1eLK}X;-}9*z zRgA6!P;yW6gPHTbPBNEf@CHLE+wMheIS+K_qasXB!WDOG^*xs!!%~FOz6Tc+c9&GN z1TegIHc6l|8fbE~jQ;wiSg)Gx!B*+bt=qs&tULbK$jpF1uR77S-|Hilu9{WdTT5%j z;)CXX7riW?s77axwz|7@v(Xl4m}OSresqI55%x+=f~I z&$A>dr5B+24F65sFx2&5+t3V4#L($u z^O@fx=zTuu2-ZF4{%Ix}sAXbrzVzZFgZeX4-y;UTDTM?cxBI$0`v;=$%2ntwq*@Wb z*`@jAtj365Tf6kVGrT@E?U^rbuCkcW@J`~2|C*0f|C>0>=a=jrXx|2 z5M_zDpDtI;^BRK3Xty!!?lTXoQoTTE2pJO-dVeZo>Yt}s2W@c>5qsbTuwoB}PC|Z! zX@^5w!>PT%3kk9=`^1C!nrY&fub9-MOFQ*Eo-(y-BfUeTSi@&P3I9y zb=XmIGB?|tR0fiXGja7dC%Z!E9DCn}9@W>_I`>{gpH^D;u{j@(A?Q7RMD1fuZiCIx zmMv6-wSBp;Ia+R8eVhdF1yWG^7fR%=$bj$Uj%KLzZ{KKp2Qnoyng zZ}<#w{-<_!mmkop#eYRxi_;#@XH<+wpNn^y#^u7$bnn1JesktUenW>MAQ}IL^y?iG zDPE6Y-^ZJDjuD8|SgPoc(qzSoe;Bc%cS|3k&1*9t4VL*;TZq4p422 zc3jMSBVB8GbR8aw&S*ee^5YEa#o<8IoFpCPxm(9WUmp>Ie17;giMPpvJs@#V_q;6o?( zEoo}iJ|OJYx(kj`{xI1$W!&SpCsMCi`srE?4QV5{d#%;JOaVf=dw^s>^|r`;WyKd- zhj3Y8Eg}`z8(y3*)LmOj6sfN3p@|s#k5KU0v(L|@L_saa?#yPwx3fOnJ2H?UV1*6; zex)iIC4|5QHT>;9nGA+k}`$C{=O1n@^VyI^V93IxHN5tGaoCB2>b}z zP`dv;o*bSH$UV^=@bUHJq&`8wQ~309hCUa+Duvkq`0eK(WvW$VP>j4jaenMZ4Tnz$ zy#}+1v)O#}1~oUI-~Bdk68&+b`J!ZH$e9g=5m#X16|qjtd$(r$5jUJspW9$Kl9^fq zOjG$ie@11W)!`(3(O3&pygS%htN_z&UDEetJ_+#Z40^izxZJF{mA~DI#xU=<;Zi4} zTzkpxFBDH)v7DP-kzY&H+XEP3faJ5WlwfC!oC1h5A7}ctkI-v&#iUXQ4y{l@9#kx@9SBPiinhG+Q=y4(ZzT@w&72Ao9?B` zGj|CA`kUR?N;^Dq6gCd+xo;TgW)NX z@z7U}8w1oud3R0vZ7(2Be)s38ag9%zWguWDjl~3+!)`GOJg*egtSj-xNgHjMO_Ajm#>6K%!vK z@)>}!YsxJ?x2}CC-5le~REB6I0^VU@F#`R~8$1E;`ZzauT8okVS5GbV5^Z>kaNZSK zt>}U7yP`o&e^3+zi|ER9zS9jdJ^EiJ99x4fr^_zcB+9b-AlCw*AV&815+v z?uo8sNonJ&L;nL0C{71N`rb%pb2$i~1V)p9JRtU{j~d0w6c7*4t;7%U30O}vG4~lQ zG<#h}h{x`~eDT7WB?ca38i?NbfwZ(VK98Sq4!(Npf0MED>V{-LgQ_^S(;6v<`RvVo1V=u;J1JMh=g*BpXH|o*F8&K#AlLb09f8qI z4*Rq05Og5a&o3)LTLjnkFtA7>eSb<${0Q0_Ko;RXmOdppP5e+50WPZHAF8wvKLFQZ zy&3%~l^FCaAt~;c;+$xrVot7T903)kY=V#_wjLa-kig5D4DVVng}S4^9|cU<#v1Dd zQ*OVAM-z;?jVFL0kdKmc0A_L6n`S#&Xm_#AN-#iR}eekJRV}A8Ci`N`e zieZ~fob&^{KSW8p+BEtK+jg$@4X^9oACR09BnohPg5i@Q{?8Z?JR5+6mZT!NW-mwgN-#hbH z4=+WP-9@f%ioaX+W~l9Sq$dvYV3BZh*!qYZUZ6loCY8XIK%`vKXEr7lEsSw)F9xF; z|4Xk^^NO^P)_*mbUZpcB^rhtbn+68$Y6|o&n>zk=qXW`Y9_#0hZLD&4uu9sXs zJj?38Z!{R1@MjrHo}}%rXg(`NPliN^8*q1 zSMea1)0=;1PoY#~`ri8ZEXDaH+qb%l0?dZW|FQ&6zfO4Ng#m>lwz%$BB7R@wF+R2Y zo-fqq$A|W8`%JI@t$t@aWFykKQD_g^mZ>^?7s$l?F1eYzu$tdXEFAo-(*r7jAhYq}W{Nf zTGm+gdl?Eqc>&=$T1y|xDj2Wxk&QCZWg;cq1>c08oC!hM{J!gJ4aDEa#!aW1K5-!= zsPc^clp`4KBCw=o!$IBW-@Ib^Y)D-NwKrvBi|Kdp6S*pxv;LOaNtd^aSR&P-aPn&9 zHyM|mce)UNj9cjF)Ijk1k9}}d+T3O=wU#^>wlFrfYX*yTYv%xuk! z`*Vv_R&k9Q{PY)>bS6=d?ntz2)*JS3GS)J{f74^-F_P5Ro70)Xc59iMF(}S zLP`XryQM?AV?dBbLO>-Yhi(RGq`L%_?uN72``PEbAKq_VL(Q!JTI;S~Y*_qkdS_qj zydoj@_E>-G5Dks|G~=KIl!~LZg3hTD0X%PhfFYhxDQ>e%1bLWEo<*aYWvJmE4SmF)0RI4Y$69S{siC#++ULR+r zRg2ZHm2=atOcF_W`_9{ZfM;1JPe)`}I0jMEu07*T-OdiU+6rHN(;cmKC)DVo3AR{l ziKxR0zZgD=L}cW$qGWHbzr817H?(kMTt$`#0&Pa;{P+|L!NRS(*O}WExtz@dIV&kNk1d+pmZ9#}r2)edSZGEb_a~#0^>@A7YX^+qabMMeUv@wCtM#kQh#Y=p!n-DrVj39+LexTFzr*0I=!9l*6V%tD)e0Q=yKf5kp9k0$tP|M5$z<_ljCcQ|e`56a(|xVO~TyMnzJ zBuEo{(e2ey0oKE*#mb~c5`2JBbHRwDPt`-OOh7WQ;He379d;t;x|!c^wUY4yySd%d z9JSxop3Hhpe6Dq7(Z=)4xS13@_3ZZD{j7cf6R>wAL>$SgxSpRP?mAehD4(10w8#T{ zOFV04>xOLZB`r~a*+teXW=JTI8df8fjlcwF@WYgo)M}UMCJe&!)Nnx)m@!9{zit2@ zrnGK7to!8kf~NNi;I>mQfYAkFiQ~fYB>;#yc-cPs$6i=DDz4SItXlCbB}r-Rd~Kgy zcDhJ3lxy}Yu8I@2r3n@t1dsx^xfvki{#^jyCP0fQGuQArHe21;YIaiWC$%sq(`x}#!n{D@SMC|(fGyZ{c z3MIqc`D3Ah$fd6L;H-MNUmmX22MoWl*K=g?6#coFy44r`Cg`YbPayPC)cwr(Bp0OU z){cy(ChjzrGbX$C$xlT0x|11HjoPwqDfgD`6YM?jdg~W$*8K?Z8qJ-&Mn8ajan7!R z_Sjgva1A`|liB#0T@)vt**Z(fd0hN+Nx{+G>v{jmy{JWyG`~52eAq|6Ae8^}uyZ2B zYTLHNkG$~lye0in${e#7=E3j(d!`$~p||u57`|V>uxVi4Y_t;G%6vT8etv!n%QY)S$7t5BD#Y>ypv#BYsYkRdjk2l>kz0>6DYOwXUjm48IWwHsc zxtA)Bsdi_+E?y}}zKH#8xqa$y@HcxV+pzbRYNs1wzSZOY*{M4HNHIHYIP-iB3>Z* zadm0EIJr7X$)X1JZe9Nl2|s44;y2jctk~#z3)@kgBf0HSJAzKDop*jRyJZ1{z*zQw zDZh88^!J`mD2pUka?%Pe7L#v^Uml&ErAB$j&B6|E;Zs&{94H7h zK^cL2KWfhQ&9Lt4r1@R@{k7`+F}au?z)$^Z5t9UrzrcV;m*q#_d3%5}Ay&`ilLjE^ z@fII*lFknhgR3}D03>f)7OeoWJC(_9(Aj+pFi*hA zAT%u1dQww?(`r)soTR&e7gL;$nCDj1E0ZKT9YY-1unI#Bgmp7c!0w=DhJ*8}_WbEI z>HL#wW&2C;CW$JIK}52vvBanmq0bv|E@KAGHw51)t1=I+0lq8XY{284_2 z1|{31>CeZ z+h(aS<3~IQTeg9mgL3-2h`Q~Qqz(w|h=e30Z*-y3L~?+88sMe|-pY6V@YmfXI;TUU zA7HqN>-m|Hty`wA85W?n#$Dq)mmpUI##G5SUjZRQhC`W6`MD1jeTo`O_lKvL&qe`e zs3s_VB)kZ+B22`0d$0WCRmzu_a_mO9P^k^CKVIkXb`8HnaoSm8ID;xeGRq@BUHlFw zmN4g67%_~6@NRq}9u>Fvq2fWxhY|kmU5}!z>gu<==jOY$y)wStDqH{e1j*5*TSpsy z2L&-&28}T7ktrnu$N3inS_m}@*FcDJPieW`-1pBI)Fu?m>G)GdNYw)$=r{aPE;D+P&q|}|UAO(DWPV({-v$L{f{lT)w=$-WiT|i~ z-uFvP%t@qV8C-=~T;b$K^Zpk!*3}A^LIKJiBqAt&FRnL+;vE~lN6#rE+u9ADXtS!YKac*@78=tSl@vuD>TLvFPn3D~dgEm1*IKnzeVCd*f{*I6ReBL5Fp55T5_w`WwYYi`G1*_s zePAU%GyJ(3oeZrT&OS=PXw@I@gTb!!S{K%lrNCp%U`;d*dT2?-lHxDMfTOt!%c_G_ zZtj{|z$*H$S_yliL2l&jmakflGEDJDRX(1C-`c1rxoo^8zTakEDOmxFl&IIJ<$Eh0 zUvDh@_9tP?V;MIEDy7Us9wE#+uSpT!^f+5WHFy}cz=)w#vUIYf-4k04m7iLfh(9YK z?G(tVYf$L>pDT&9#T_b}HiX{SL&E-9-QEpk7o8boiMcfF&uIq#?BG}zgF6!%NiLYmBbB|iDWQ9Pnzwvs&jI?Tx;&6@mqr0_2O zOR_cE^h|3zIOR#Dd*bwn`w-AmdCJqjLK$SwcjmzuNFPEUc1l&mNy8sFN$wUJ&RBfQ zJ`ra?^jZXfGul70)%TG6{lQ`Z&4rG#J*=#VXq^6UXFj0w{Bi6hTZWN>Gha$JNbE zY$8ig%H-@gN)l;107VWur*GvJe_;|dvFTG&rz%Y3p=BOU_u89hAazStTJqNnBhsb4 zIP46l-Q&SCdCd zSi!9%mN%SsKSa}Fm^^N=BR?W>UL@0&M};Z0ZY=C4-*%A?Q(|=-;}qs%m*?cuWTEHl zlx1@JAoerseZvoAA2*6N5(<-Zx;pRno}qTbNOMlUuf-E$aLfNEKZzP&Er?czkoFGYeH4@+zZHQyEO_Pjt3dmNw=La8 z5Mr($XtE*Jc25W~=ggfe#)HCwA3Kt`=28=$=eq{LghmDxYnav+U7jkv-f=<2sSdQ% z21kU5jH+>sHo3EJ9=(n8J^oDaDC2^`8N}R(ndd!%bh&vjJkqAnjVw_Vy8LN~IYH3J z$F7rW%(%eSf&3m;|Fs5|(l zo;>#To9?xfhTJby3jjm12@o$b`B8)}TASDcUain?OTguY#?c6odz*AP1H47XVl~c8 z#u%D;nug`77j4SZe@9sKzTKg=|FjhkuZq_7NNz( z@r?E~ZQygdb;M81pB3dN)k#0PS;71iMP)NvYQ0a% zYf*SGTbNhkW|y}}EkldYfX{DzpA=;sfX-E9OwEAoES`RFFyZA4L1x%^#Il<2{5L($ z=z@2^ievBlQ2U;qKxRD84g%BWDWC7N4;deBwpD-ReS=4z=O;q!N$i6{OnObwtfSZoWb+sgy(q+&*V<$Tb@; zC^QE-5ua=yh=gr7+xS}rq}51Ns+)+_uRVy3JdBFB5>|HaKpc278qSux>Hb~8+z?TV z3?9p5s;A-fQH;2g1Ri_y&sG91a>!q#v1RNQJ)AE%PwWFoJ#fQ7o+wr)o189Fr2WWg zBEcQYlgOlztuI3<%4upV(f1y!vi01s%H3i(^BIud!nr2rG!`RZH`J-HMd~i!{*j8o z*?O8h$>3v-O`!q=>3_|#R_~MX+eoy_jkr*OB9(3s~+}3;3=$b=Ws{0LYt9&u6i{Zy7 z(`qgKx?xc*SoXiS7eX8=k!0X4byrVs$&^R3r0G@>O@Nop@z; zXaXp1>@LP7l1q=(#p3vUCuUXG>T{z)uEgBypG_#5%9}v-H1VFUFJJ3y9(JcFJ^q!V z)rCwPC81-%xjoctMU}qiZ!Dk5+s?o2zo(s>6JE0H}0Dc z>EIkev+6A3KYSiNZ8b9QDKcmzxI|>1zl&9wvU|I|mSVH`weTEqH-IJtwM_vFw(!F|POBQ@!4CO0)iyJTK;ujVp(7zt)aN4Vaj)d|IR0+bi6 zP@#(iSkQF8c5Z#9^N(q_hlKUv{K)nvmZIoKx)Q6ojCsfwF}>B+5eN?9rB|QkmyrY9i{FY|)K?P;o9f?-{Kj_)VaaV=QJ z4GxW)n!MD{KDCS>IR)7SlQw&$C!Q4N-6&Ly3N`fyYmD(>UHd7d|Fflhwh4M%NH+T;s2fLCIu(B)n%^bjb**I(P3|?9Qw^?L(?IV@M4rheUcj)$5Dp2Abf+dnS6&W5!U&;IO`?WN z0Mx+U?2%6|@OOk7-XdzqFCwD!KQ_RvM{^6#-qQZ16!%rySSquN6VNzTG;oaHe1R2u3bpC?2h*_aX{E*k679 zpnt^mDTdYq0QEEx1pK?y46@>LxJ^)r6Gcl%KRM4|7<{J*f~?ET+qgk@UP zTiEp~VG#jsAr(a`HS-LU_M-})cbvyuRqu)naz?3G!N14KI+{Ev$w*B&g|rwQl@qK^ zvoT*zfjN&dIgz;)NhlFlYVb2%+I}5)dOtO$RsU?rw^HdppaoyGVRULRT$8=Fv2hZI zoLKA*f}EtJBn?I!o3bQ&<_ynoWWguvX3yW3A~Vr*6qQ~H-YpAh91_=^7YDI2u_sX6 z4hh_@IDEs;1gR&ci&2KC7sprK^5j;VVQNbE3kO&7 zKLrCiRgUIrB>Eq*L@?=W1~+;o7Of&w1-r)Al*anezv}UiH&PWooJIzj~4xjl%*P}KM@YBteD(3Tv8r$hR=&@UpDiGzhmJOC z3cpayHz-lIPEUkquZ`5kcu(6QWO|@nl&F{eU z`S%hYEtThQd6Vi|I?AvJ*vJgEXW#n7H8;`6RfN4$drJn;=$SHjKkPHQ?6oiBMlL!}+#;c% zFMPb)Vtt)HQJISI%#J1_JzK)|QIv6oC9e`R4=1c|%TF>g6q6-FFo`RWoHq zvIJiUSP-)$Xo^A)-v%lud=LYsJ2nzmujwI{?a4`QJnN0J6JJu2!Tyh4oizc%_JGyC zXh^XAxb^d`;5&0fS7ZUwoy|n zsI+!8NN<=JS|QJ?!AOU5?Z4fA68Yb<(-HG^WK1S08FapaJg26%Wuy zqerK^d&t_Al|G}(FG0g{?-+j(g4jnLe~CfGY+ zSh84h(Mwq?yM9pPXuKck@*U|?N4H*ss^udmY?q{mitu~12!zu|Uvya)YD*35Kxa0~O8@8IW{Zq>&pqsTb;f*4)NBm-O0someonknuv z-c2@k{;W1fDXg3azVU>}@ja#{F#;E`JZIx_Vjbu)Jvn&(#J&ii@4{UVsuLSOSAKNC-0Q5vg(v!^tF5f3$@)d!rDE{u?~&-)y3`sD-iM^~G}z<;&Sj9w7S9WxmCx zmd!^(rx@S)sZ)0Bn-`VG?j#BO>pF$jYpg@-R6FL^74Rw_Sz_@RRP$}B*O?8z@_p6) zmY@(%HYf^g`nZPjw&9$eIYVJ4eES;-i{L1q=YYK}yozjiwBG9q;4k)?+2PMtW&p>$zt*=Rhr8SR8pkwSr}rsI+Q(u81#*9W|Mh zac@peu5QwhRQtAC+515AZT;nHr>yzJTu4*-^~S}-@R9kw2rRj>KIQNv=(wI8hgGMOw~{NqhVRk?t?c-LZt=O(yPfEoKOU55;YFrtgo{RJaFvxoN~XT#3A4&yUn$sd|_hhgxsv{O))}o zniCtdb41F$)BvOsHi11f2u+M`J-{+K6hdalAtPfWH}ZAwJLsUXl(-8IQx6|*Fxud4 z8!@rTC0JQtmYUVSjdHUyFZ(n%XPp;7sFOIaIE+Oo+5s2$g)aR?-1tk|di%_cqkzxC z^y)0O$g)^u!}jI(AF`>T-tGb~{55;yQ2!(#bxP#n!}KZ`@B&^DlblQ{eC`Vv`_#?M zw<|Q?Ix#>KA7T+FdruMC_hH0*yl7B}Z;yvky;(!5#tHw|L*E1Go(QJjxj(ns*E-_& zd6@7wor3b=Wfd#ymV4LCNGlY{z(C#VQx1{lko5*KE?N%lR)|I9O3In2M@hQX-t@*k zfpN+xDLxR&aC>ZCQIFsYhqOse_CZ#;sL|>yr-OS`HUsa@W=wwp$QPFG$nOs$pKxA0 zwoRtq>ZooTx%nJ#qUma^q*uKUOKS7763rYDe(~crFE*eiTiV#@SC)1UCYLGl-S^|| zQO*O6t)UEtnKfV_%QCNXT`~SJ7-wdOlJ**=}V#Tx9k~2!b@E(y`3a- zbR~^kB0DbKh@UZLMZ20Z(bjGLcmKvmgesTrz)iL}a^(ra{t(OYjrKvWh%CzK3l@&3 z9Q$C}>y8VEdB$0@?|Ar$u0_HMB$x>3vAjsa<9%Ly=56tqh6Lk5OH8TZk~vd$UB=Z` z3IPBuW~+NhxZj`jvB#xO*-4EaM>!A$8KIBsGJ1nqV2Z847Y*RtJIs4UoC>G1^q@<71(?~!1Q?Ft*=q#cev&s~~LDBo0An>TnhGcXmqPdYg51x>BO6uL)zBPMc)qT2(Dq|VD}4F@OVWVQ zmv`BGDC=&zT`{E0bcwU{yC0oIoOb(mx#Ula21_(+eu-(7BD45!no4-DY)%hCc6fm>x0H^lXICq z0cOyl^Ud%m0SX)Do1!8t)`My4A>We-l9GIVeGRji&7NGga<8HeGbk@BuU?LfU&Ol9adT_Z(vAD9tnorF7uNN^_o9IImamnJR8rRp7VB9}wT<_rfUI$W8d}ke zsr$rEvh2W8jc})R0gY%B*yG(%XC&C=HKsH3xhJrYVeC`8IW%mHk?=W98MBp0s6z0J z`#D#Oqz{Kg5?m#o%mrPk|J5ngLda@KAnorCRL^xDY{O@M)%HExe0sf(X6s8zQ=d46 zm^&75ZHF|e8+vsZWHsRe?k@xZ3=ZW9LW)A5gXxNKaMr|SVQop|rL<&k&raTo+tD^g z)p6XLeXXME{(DQMiPb%hA9%^9ZHt*+HJLt?8>BZe=g`L&`}X)1;~Oj-+ztA{oJm=0 zXPBG@6|a5HEXV0iR*f%CauH|ruN_bQcD#o72dFFTJ2&{rhYej9qnOkLWIaQFU=QV^ z=FjY_=;bFjd}_PNH-s27eMHPd(@)~h3}ljFRbQhp58HAMSb3tz!%sS|mm;4^2Kv)z zyt&4Pge^#KjCW#Q6TRE|&mmVaEh{MgVapVR&qK)KjQqGd8VH zpH_PPq&6evj>V>m(4X?B+Z?8_+p6FkUL9=Y^)XJa_ z%uy!p55Wd~VdUdz{KaTDmP?c-^n?1O+FNC{V5dmyt3#ma=1QKKtGrW@XG?L>U7M$X z+?9rS;U6*9hm5Bqe@a_M_{^UQ5d9Lt97Gj86QH;HY}Uki7=D8LZBrc6v74JQL$4_o z1|cY4-ijXr_3>3n9^Ki;m5%;{%nfp?-2BCsHm)$-&*iu`6B&kR}Eh2jkGeCR4m zZO#RhJ@_Rc6`uL9r0glp8<>am<_22a6jPP^b$)X?)!9 zGNN%~6*!R2H~T#HJ%A{K6*)0u9W~JHgnO%UG0s8Qtw8qTUdswk_$T6b{V&JNhF{4^ z8|)&B1|)FTd#~RAQsi8c{?-1M)?^LO=In)aD05(hp3-JoMR%{1>3Gwgf`MyB;e_Sw za?(MS>nLvMrZ~0&kOndOzKlN?e`7QAnJ92D1t+8Jfmc7pdgj^?wY1;j}_%y)$ z>l3@p0@4abMRX$jN9F|LWpugeehUz5gvO1K@xFbuq9)w!`D26m^-WqVi8CmhDsPor>%@M_{TB{D`ALlET8|Z~c=P-OJtw z`=>`ATMcio>!iC;N^^%;UaE!2kjHEq?r=!f(QM9z4+A5f@O>_aQ6qJ%)s|gcxv6g^ zm^3ugKjNmSD%z;f1=kz%d*d4ct-~;@-`hovQotb{G1>idYOJc?T>>y+L5Sh$YQeo`xMH!LQ~4BEetV`MI3;?>UG z5G$x>ux|16#}C|2p54rP{Y7USG7R%R+9)(1g_~UuF#4N<=%$q;)k2~jfo7ws8R3YFS0fFbCK!2)6O{g zZB+MBcU0F8NRpK$qhDNZ4tnKi6?rW(D?uyw$2^~IaHcTO zm5J4|23+$ypl&?wTn5M)c+{-*Asc{VliI5@HzFs&0eYu{DN=W@8X!Yh+FR^`eOKnN zBo^NDoEF+~I(6E9U-C+Ae+Y7T%*Tiob(oPX?d3@GK|bi7AM%gfoi+Sv{Eewf13>(9YKA)dmx!nK;5uTah;wjb}p{+fdcrAlD2Pi z#sQ~1B*}+NGJY~X{2-|Oj+iA=L+=VjsG4~_3I6V3bWM`T@QjkJ+9~wp8P9xcT-u)K z#%cyDvt~W*UqO31s|+fp8$>V=>HHqw+~dDJ{LmV1h?TXmQ%eFwKx7E`_J3WbJ^`|D zrUyr{>)$j#TZx+=SH{RiFZYAZ?6QAuzRRa%4lcm44k2gC4GG(I-gnC}E+CXBZ8k;~ zCgP}foAKIMDX&P}8}H`o0XRDyA8r2yWb0b(4Sv=*@DfM6v;^!d@FX+ep4A5^r5~!3 zxJIHG;b7RV+r@}<^9c#R|6*PCsgLc zX39+dH&n;Pc6n*gGsr>L_pVfLjl0d4G>EfliRzuo$~`emaG-Oz6+%zswkY2d~+1kSJWPO@ivTE8B8f_d-ttEW_Ht3Ho@Fd#%46&>%_p%nz!tJB@NpO2?X z!kzvJP_-1;jf!WPe0t)CX`efn9s?*ph>n4Vc@R9kC~}0z%>g~7+Z^Q|pr>*YZoY-%e+;F~R zW}hyQNB`^{tKA(-=ryxWYU9e(DB2s_-0T-4+R2Wn? z#oELRiWkW7jY4r5iHR2~UJ|U|m!!c#fl?@#rq^+!ju5~M=#bo0R1Rq1I{~8YWvz`S zTqBsuKeF94f!**@ns@`%tM$TMrbsmb%X+mw7e{+uwwj+@jszF>sEsA}iMsio0k$+tk~>aDaJ&=x zW|wgVys+`jADY2z_DCbvE8l^p95|d`zp*yX+O$XSc^nrdPY`!!e-yh{!tpq+mZI^* zQ#LfZZYT#67@uSu^j%RjUGyFe23~$sb@o}ZUscL(bbHzDPt z??^JKL^Z_t&TJB!=h;6jcVdfQuSs0Kbd{8dK8_AsiLCF>h8MOklG0FS37c>A69hKv ziNHu}4MeSEm!v~~S8fEdOn7M0;Pa%PvaBSuvfkIf3lovXyi>N^T3kK<&f%cpF{a^- ziQFEt>|~JPFeDCBZlW5xr#VWU5_1q!oxHTtFz6O!9L4R6)Bda%*#;4|v}>|f#E8a^pSPbg)wg>S0WNMC*y zeW9ry`t4cPLvQL0Cqs#deS(n^$p4QC+2UB-KSnbp;^L@_&(p~NaAsH(e-SK0EOAfu z8p66K;(NHuPrzzGXwnnI4k5bDTT2^?UGO_b_oNVyT)e_i-sh)67qFQ|c90nH->#?d zgWdeP{G-RFUq`anRUh~kaJf!@MId|Kj9#kz;)xK&BiAmz#|8jWd?ArvekRL; z_6bzIs{KrjY+7B*p4oy{e;weUjyJ8AH<{LKgy^Tct`akUCj36>$lWl$ zva=XH!ymC3bZsV=JN+%FhF;ck5N{luy;CGLA`VVr6H5|?9~6Gm+tI?9EH>j%WEjP& zCF1Xl+8yV1O=EP!fjD`yhqlayW!rh1e^kG}fXqr~%?tyd_Db8mR6Nc$s-uHmhnZmu zr)%#06Bkgv^5p%z%JqNZbz0EgLe>2n&}J#e&z*lww@y%{?zA7&EiM0dH!!aEngh0t zQU(tLU_3A)Gy)PR2@ysE&;K^7L6UhoCPwtru^)QnEIz2GUDe*oKsn-b_DOp`Cv5nm zyo&xRz17g@+IW8RW*K*7LVo;d2fJaj`Y|JN+~sTP#Aobh1uRp4Z-W8+K}Gmar#YkQ zmp_?jq;&fqO;mbSmp!gN)T{SXe)rFX53k#cQ+&ja=phVc&FkA962E@Mw_G&0Y}yrv zbR4KExJw{pKi`xy#0jc%0cQ!7)Nf(uOJbJBJLhjKjh92@xlA{`Kd|YMoMqJ})3Wp( z&KlA5<EL^H&BVN zzhe9I#oXIGMK{Ry;080ric>nxLD=#tVpzOLi_y7bD7MtNfyONGURWZQL_2vp(fG7G zEubIL{!-JR&8k%)z1@y3=HqjwN&|Ashd-=BfpfY?$vZr9HNH8>$yAE^8~A1s@VGLP zKjWSW{P}`mXNsE*)-i_w28|r*P!QA$CinYv9IRIFFc(*2MB9K2B^n(o*x<6@m}J(G z&!KI;S`3QzvAHq8#aZ0K;6H{tBWh@Hg3}O^w$GMARI#$ny>>+N76H0gbfjg#o?D^x*sgk z+N?kLLTw^Cs}9I+t`GSehPlWix?OJSFKT~%JrYM#Vbl3;`JkMk z@N~m7BdfoC@!Zw%jqma43}*Rc-tX`8Bz~}?I&8aZPmfwBk}e~ZrN{Q0csWwMc9{0xuf zR$pX>xWZ6(1N7_K<{u32+u0WGmjCS#UpkrB03!(jo}_?Wivq++`8n2gA5YvjsRZ$_ zq=HUy>%FmsfQJ8jsks4D$msJSj+=lx0UrLOhxlag?_r*WgtnG|#@M;c;%KzHAYtdtJ=MI?1mfxc^?330G3q05iK$t)s`AU7N7>O-rv>OFvAzja_AGH|jxc ze+ArGydU9(KPz!AGtUD_WaCBvCI?*V!N#c9|2)noPy`fgYRv+<%e{t`qOR=^Ku<(2 z^5yRHa3t8CsK3JdE|Mhe^FTraDJQ_OgJtkC0ip|ox-Sslj0ep3PjbEXRrm-EL5Fh$)j%|(+= zZ5?#-B9o8hcx!m)`%?#IRH@q_s`od_ddA< zj+RFN@?i>CJbiOcoV-=}23T$n_cxW^PXW>!pDX@lnV;|y(&nSmkNcL|m2um{y&qUl zB&r8abi5+g2qZF)LZ!+RCjFcOOvoeS(M z>sRlsZthbn77E(_-F>zG=IdfD)p;J2rb_03ynBOl*-FC0LLd-jw5SoyaeqEQ>BOGG zf^+d{$%wH>I$6p8?#398^-pciD{YszC!-8h9sT|NjY?r)oyN+Lg1U0bhXyIW; zUOoQMpsfmp zG97Vc{V%sLKym}u7Oo3$vnQ)}UIDazPo*e^%AlR;*1NqUF$w_(Vyz;jV4!Y0@7O6U zkC;>_?r(ZBRDG$uly;Y-hR}3LGDusF{7yKQU$z2o7_xJrWW!)ct-98ylw0_G-I_U6 z?dq@D;jD0DV7Fw?Q^o%lM4<)u$d>M84N&`?vnkddi;ZvKs?=D%JPY(|a9&9(X#rL= zkSwaa)#*}AS!7h;IO*bhm=BUqLGwP$KlweQCmV^v)wVZ*d&Ne=YElWUQP|l>y%3E1 ziG=r`@8bu@P@Nq(JL-{9G;Y?1&LR)kHeiiQDS8KHBga_!H5>13kyH%M)!$W5&0lc& z&H0vm)r^>=#fMcl;*Iiecmhaq0=RLMsRn1O%~kCQ{@;SZxv@EP!9GOeGa+Y=%!o0e z=@X!g13H~75G&|)Co+&8aiBt=_}HIVw=zlxWYr0OZkI=FR|Sm~FEH2s6~ilUmXolb z@dzrqAv3eV8c`BAvkTK{u2f zE@6e;?*NNZO#j{r_|lptb>_My){&9R_UMzuA2>dq1`d;a8Ki+!^IX^%O{nYYZ-bHr zoUk5LrQMt9YDLN0BnHt5>uWgG>4T@08Fy}KMJwMA5P7TzqtzrW;0DBv0&GZNad04 zK#o+x1^32%#+><{;OY96cGV2|T0;Iu$dvxmE#Y{aMA*Kq*fYL|$-^$a~?;9`5hf{`_GgP$;H?}t|-S4#cz@=RU!uwk`_{<^Q+0iirw^6oiSGYO;ZVuI_!xdQR) z_8YDOg;+uWfSs{6KOCBKrwz(CL*HaW7OwxB&h<9Lw$~`4fh8%{AH;^GY&cPi6bR)sZl5Z?btq6t;A?qs^zI}z-#g*@8~X;xr%kN>g*F@O9dV+BN|Ky1;?+|1wr zMh}WB&N$srBPD`m=?Z|#o(1r$Pqw;zh(T0zjZ@AOct>rw=;MX`9rHP5X;)fXxGpUT zCQ_LmB)IJJC70#;(Nd=J&J^O_%cP4SyDB#@b*7wYzkVv@zCnYX*B+eCEvdM@>?}Q! zhH`;dQsH{1_g&l;HSmPxc#|{Ezx~-OYHMy75-9T^+K(3`FSVj;06gBXgkE;#)$+B+ zO;SA!+l`=~pf5mY%WwmlZ}E7o;p&pN%mBE3at+A@(HuC#wKzy_Kwv)2pGy4Wh=Bo` z8#%N4Yx31u)-vgY!^uw>k;LBur%Pv0^{86)X*$+FyYfJVySq%gLmv)!mH;x|E@g2*PRVevBiaUEH(Lnlg- z38?e)lyd!*eQ%V5Jj`V~O}jbM!Lqe&V%neolQgNd2XYb!%iK_8TWG{kPG-b-0v((W z^s!A!d;X5p)OrlRPWVvznL<1lhse87m!vIPK(@X#sW;N5<6m)XPj~g+Gmi~~XzWTkC zMKXxC5{R%PD{-eqp)XN=(}+Wg?&rvk3^49a$x<01D)`Oc&$8NVo)wRnaD;ky&Fp;y zy>s*4pM8JR(C`@QgR_x!hqCOaDq_RTzh^RZKQ|ap9(+3Dzu%tp^hf17(-`#uV};$| zfrKyRWBKB?+2|NTAuAvr8K+PwO?T+N?|;eX>jC%zFVB@hry+zKrQShfbS36K%b?Zw zUDSxBq89Zw<>PVdBHc_M9B@y%%Dxr#Ikx!2vR$@_yLm~bLz7m~5Vc78L*?<4xtKxX z0RmlFvI@oq@}O{j7*s_!m>dO{?Tk_yn@7+c<%XEZ$VlNmHj1FZx3)(j#%jR6S`slDe7%OFk zj!ML`!}! z^erQM^mUnkj+91N3gB!~?G>=v39C zr%G=39ef_d*z`7OB|_aq)xF|Fjm_iA_g>_%8#HRW!ArHDmM!Q}F5(0hm*~{Jf?xQ? znwbvXB0_1~KXwBAu5!g9za#m<1)`R|<-x$1{{z(rv5?Cr*H;aACpsV%40`OO z;UaDNp8j4-qgFg>>$=-Ih`O;|^s+o~Zs7=@*`P4~ETzFBW4kk5u9U(UkGV_N4C=sk zeLUme;cnZ+9_J`!M$qrx6GO2HgrR(cpE^b!pZ9af!=HuGoE{yiJ7GJ;fl$dGi`tJeV=aUIRCj$rt_!?lw_b?IUKMG z)>rU9_K2QaZyQhC-ef-nP4M|2rYpn0F$pB0`bkvY=i?|2zw)Ae>)DH?o{ZQrWrSuA z5}vn%Z7GKVmsS^%YX8sYr@vlG0WyE}s_aJHS2eT)OJ2ivv1GPOSO!AA~CS8TH zf|p|6;qm{6t*;D=vTNI=5owfe0YSQ%krrtcln|6|q=)VXrA4|?K$Pz8?ovsq8A`ev z_8Q;!dA@J&WB>CIn7L!cwa)7dJ^7CSTYXg8o(6p#IJeR;Qf50s$vlSQF+V=kEwfC5 zi49A0y}ob;soGQ^;ElI}AP2PYzOD_~%V z=6K}^yn|hbhxXoqq!uC-9pIR2V!yRs!44kYT>^w#gc~j5*b}6OcmM2YavL(!Vkv2A zG^du#wB5l1AiW_Ma=Vpl%-&b5huNB3d-)tikkF<3(s^(MQ=jye?yA-AlthjOEj*X(db z5n|j-1C|4ubo{F&4%Zx~dX~HxkKP%zc&~PsG1iU*PL|tHap;uG)2=zrZfW#v=;&t- zn;xBh+K}FA(0I6gzw)|pypX<}^j$u_R>x}FGX<$AxunB`gIM;jI|r#tKKyvmdgcTX zfepgQ(E@JRqS2=!h1s#4TF`Y3HT>T^$FQ3B}*Mj~bRowkM*%%F<`k zg?a)CZu|XZ+5J4;LjH1cNLk{S)GtjAqj{{>)N8@7;pT9OP=o*$HjD=ECpa)1Y!6o2 z4oGyK7c*Z(>C)Uui_^Bd^7+a&`vGF9WHeW4+AXuhAABESA!bgqk`yctM~{xV87A%Z zb_>P;|4F|4yCD?B{u=zQch&+H8MVSY0EVDmL4DvW)27=&3ScT|0veR1hAi_5C=r3p z`wbL%P6Bz~N=XK)Ke3ro0~(~|$Z!b|kUK^hSlIPu9_I(7V)apSjg`5)PJh_n$pmS7abhtxKBo^m7g!1&STgGv4%@F>Iz?qfSssd&-jz-x2$ z>;ya`T1g)oz%UVh;ldVjwLR}jc;2r3wym ziU+P3j&1&PKscBiHgox>@19N8@t)*d-6@J|Jk-=sBYFJele;h)-g}GvagPFlEn#if zZt!KV3vJ1`;*n$b`1KoH zOb)lKwMENXub(4lRiPiiQ^QCDjM+Y@D-uC}IQQiC>M*);cGQ=;dnDg?7>G?<R^+;tC1U$^9@!m7~50%VLMKD z{MO0GT+gkX{dkr^<%495pt5^TuO2D7+{l>ikEn1f2**5Uq=>91T*?q=b;k(@(q*GN zK*ZB?Y>761ngK)CKR@~)=d|ff_N2M*T%ODcvg|8wXzs0xQ(kxTv}GxB@LDXxOtY&L z2gbb5#MCW+(f5_VZmW^h?Ael~j?`}%ASH=0mQDi(4aOHo8{j!B9rc%)hEFe^drR-^ zG=;9OLs5N?X`*c!2nROTM)Vb4Segh*?A6(CP~9ZAkU76O=?Qagh}73E`#Cr?VL;oET3V}VtGFv}bUKX`wPSto)W-N@^5=9eOxaZzP^Tdd(l32*$o7%ZIcYc|7i8g5VFeU*|OWSH8n-g1TxgV)o!j|kj=}E(8yeFJQwtlOZn%9Hvnaw0+?uu zx??Vlg%HrP(_wVSy)BaEs8ODP(`8)mFSq$DZWu0RV=-RoCkve*Vq8KcA}|lXGV0C5 zu0J;TzR#V3?;t-Nc**Yro<`bC5ezH7hFXyZET=8dImr zq;B&=8J+Xc*gW@sZ&OuFz}&Zr*PXLR2P+(KPL{xGPg?1(689oMJvVi+-iijF$-FT1uH4B3W(Nvg5+O`l*?{uOWIGePH`6?|`avK--c85@Z@Vd^dvfXLjp+N? z*VRnhDjK8TcA4QzrBCl7tij-NULTb{2p;^gH`3Ot-3SF%2>D_TF5SEB8VlI|bt6f@ z!od+!rxEtsO+g139NTf_Fo5>$tG;iVcl8E8*pB-^ibG4!K1`RDNy>q3Cob**SFdiP z+e>b`L3jJt$GZ#k;E4&nAc%i_k{m=6{0_c!tXq8AOXKBX*D9`>4CUrh?sSA zB2uy2@u~QI?dI)EuQxY+=ren+dU8heN7yF%fyRt-K;zF>4CVsz3_TmxSi_srs>=?2 zFKa&rUDUNxyY3G=>UiI=nYD{7DVg(K;a7nDs@0$CIlPHp+V&(|(=E+>z0=fMUL+rL z7J65!?&ca-HV9nr{2K-7Vzq@25T~s|hw>F8ZVJz(ND4kQj@QL;36DTFN_SU#lL!6{ zcEbsszzc#`W&N8?K8lPk*kuJ#TfR zzyP&3mD}&*>40y>4aQox)f9Y-Do4B1XvQIxJ;TSJY<6!KvRx+yrUouA%dF>B#E1CLsd`i_T#(rhKy}@|&XOe#UuEPJ zk{7#t<5xQ9Kq}M$AW3XU8ea%e@vQS5Kr4B2>LW8&H;Dr5=zwrOqV-B66*H+e!l8-K za*ow*;gUko(R^X`yS~Z+2C%-DJaONqUp=|exeFsCT-@N|vDO#C!GZV4 zY-wd08|)@nsMpXSLHw;2I{oTfEivl0<+*WtpJ&nw2MhjZ4%2w%>N*%Mv;blHV#SC=M_V8^Ixy$4O*=Yp z6v4msMEzw;FkNT-cSaXT6@!-J!KGf+yLg9$6;NR0AkqL>RaS^m*Kddk6ATe=hD2S} zRWRwM*zGOv(AlWllhDrm_?RkRbK={3>*MWbbsOWF{p7Nqg(E z=TMPx>?|$eSr5$JwmiK&r zVz_Dh7xKFkAP#ipa;nae70YVUWOd~7;(}`ar+XcGZ;@UT)$QfNE#g7|#D>DH2f$%T zu@4G5b;fmEy<>uI>@gzP<35RKgvB=G(K87ez5_#Fj&|K`c=NUg3N9~8y|}*$g6X2hCsg$&}Q=*p{-SEye)N2@QAAE?1 zl$lF3p7M0oF)5p-*6`THTK1d{b~t(*>7<*JOH1#7#tG7zJV-iuB}4m>xl|N-b*^$E z6)Scx=x8a-l<(%?rTr1I{Xk8{d@ha&tQVMDg1fn9??N^P^v4XW11E{`s3q|Q$V=Kb ztOzSNRLabHPzx8lpFR;HOmBz+E(3;vCn>&!nWy*Q{8p}Xoxqq2qqjDl(fF&u%+(6S zt4#$ZYOnj)958Uvzk09iCN!~P6xI&Ev7wgfQR&Lj4kzu^X47O?%}Q34PM7e0lM*$- zWaP5*n8_6dy6Sf*QVKW-ewJP%KvdLs+F>;SdSfxbsl^LOP+whbhSK{GlwGe80!R&^K{)r8d+2@RcH5PL zU1Q+lz=@s)*@xRowYX+uKGx!Fu%cz~n;=la+^9!wQSLV^B@N%q1T68F{(Y zM1@x-R`Z&O!Nm7{g(L<#iEF{{SnlYR4XeaD(7Lwm z0R~X8$RU+iz#9Nj`XGn&E&!X)L6&I;=*&VmqzrLj_$f80p^@@}`4{_>jc1v{kn{>K z5Y|=8$ymHl>cFN}_C3h>;j=UhS>?Dk#mokanfj(H7axu{PXensQ0bZiWW(UC$prv( zMdJrb0)Xw0Gi-c7ROc$3bAi6o)8b4=Cbfhui2E_Oru7ZEtvCrONXu;th zL`nc=bN0wp!iK(^!LvBi^m1?;u)90}p@U|@=d#F!>q4Eu7___(-e8Ak4eOe(5HL&v z&z*Aut!!=bg-_?(fbQP2mg&nKtg`M7zXdit3FpAnUJ5v!p#54pj?55031`XL6V21& z4c?@c@X~7Kt#=uocq=~^-gEk^$zAhNW8fTcosZ!&mfH_GI6SNZJJ!GRPqde`KNEL< zxw{=iFytuUwEYpWNy)@|>k7A_@=S30ef<|Naim77D$~aEb>L=AjgxRgXkM<{%JAwgv&d%-+CYU?#&2l zJb_*hm&o7DO)vTy26AKg3G_r`>Pkm9bB?=+AgK|4&!Zs7wI)~ ztLIuZv=XP8*{4}A59*(P-9d9HsNU|_%H7;MCod~CyY!j++1wv&RejgXWwSsGHvE2# z^X(!H#2y7`>SYF#`1S{QTJGpnT2GIcFXX-4L4jciW`u5~zq?Q0%B-C5?B?rQCSk;x zc0Ryiv^_fh)Y~rw_{{+;nl7uafix|~0dt^%je$6m>=1bL1hYieELL0X!XO+=i!-R*y`;+fOv3ox z{YPh?+y>VD0V8H+l4@$@Qiwy(QL@VNgrSGGU-DhN9LN*tZS0~t>$-UGusB_tgOqDJ zWwF069?ktKjj(uE>bqUDTr$ad<;yw62-+D`E~?kn(w`lCbq$H{AZ+Jq{Q-?-`E

|Gres{?2XkSX9a-s)1Q>V^f) zm<+y48hd+vb;Ybx1^6-{xrKmU#KN)$cnpy%I}GGRZ~9af@pE!0a5^Z3JYB3rVwuD2 z(4kGeAFzz;U?%sYXAftsc??jXHgy+|do5=(#qtW=@VRMpPkwde*}x@R(yMQU45)?j z(_K0)=w0m60%E&h#=8l=m8Q{+4K2l`cQ9Zy&xobEqiCslBcYsl97aB=@LgFZbx zEMrb2E{A7n3|NU6``u$5GE?7YOY8h*;yp_$kM#`B<#=!wj@CzMWqO|f#H!~rX$C!? zqx%6mpjPqJ3Hy?E?8!_j)G`wL{9vD&9&9X`-;Fl5>KhoRg2}{U2C97dS~D8 z{$QdSdfk0DH}9TVw1@SUmme?JI2BTun~%~fG?K)E==r@vl}vFl)w7etVEo#<@wyt%D?W;qHOS`l?_JZICo>C|7}VxZLl zqGq3~O(x;jwPapI?DFCVcOD>+d2)WRlKVJ~6#d7jAC03_lCp7{t-6-q+^_3-ev9u5 zIq`0tT1m_{A7_z za~*7jb3nwIJm!nA?c967bqfC)SyzJ>$sAKggh$_5s_p|evK6XUc+YO+&qv^-&epU> zVOHmR>offe?A9zTQ8B+t^jrKsv*A_bOcnz=7n%TGo0#XX3%Ru=VuTw#8?b~aH|cw@ zCS21qsv!|naT>y}o}nS-&Rrf#-@@7ii5B>udVG!yA?W8()b`2@zaONLgt?WD!x{uLWI%LZcLEGzCp8&Ak;jRow<{)`ib;k@XA~j4SD$a1TXEXX>A%*-K!a)q z!(H7#*Tm1Vaj613PJ&@sZ3iHReT#v0tVdF%_Wl)j@q779dYj(|BS#C^u(16!9h?1&0MUZ->bJpEwupFI3FPe3YZ>&FVAe?=JV{s|s zK1?#LsTO@wTf}#e^vDrWHQ()L7%$Ir2CqoU3{SL~?xr0ti5d{JSoM$83$1Efl{`p! z-e}OiLlW~t+Xs^6_Qr83>iWc^iP^XrNHxk`hlgK1v9H7sg z73u66?8=Q#cQ9b}kkPf%R^y!TNnY=mkx`|~g>&Fiv@&BHgi=j6+9KoaE!%qn^tWFred!5CLxx?J0deVk5a7f;N{h|FoTb3;jZ01BeD97qj zva=y!r{Vti=q27;l=bU5iy4W=Uw(FjqbF1mxlbeeyI77KR+&}aW{TL*rTUsH-4?m+ zBn;7p35ClqC(`@I%+*gFKd42p0?5Fut7=Vd=l)4dI~CMz_2^dY$a!s!{gMcfx*3$ z=t=CdIHaQ(!@_^rP8j4rewF(okaODcGmvADux}JJy)N-*>TEbs&bXrBH!GJSO&m$% zVr_m;io1}b6xni{Ot1Pg7QZ`M=SUu~!b<^h*XBZ*M$8kx?(p_uhG3I@zbt;emOsBk zY|czItj%$c?NMAiO%rrL_)ejNuTnw*vpX#4Y4K#W8X<6{y0efwu;IVaw%CgI@sc{Zrt%$bs7BcJ}rHaPwp^wA=bs zj~thGYqAS6^rZ4k!WkCy0Ix5ltY_BvR`-=xP)X+eoDGMRk#f?wnk7{{Bn#kEzmt2dxH6{L3HQvWR-ievnrZr3WM?UXypkfkQ-yF@eo$`QN zkonUzy^#(^4Q2t5eA^#SFuOkS<^xWyvLvg31FyA4X5J|(R3g%xO;Z)!LE?jX6#czx zXXGh)3%y5G{>Jy$uB0WZ{qQZ{@IrYT-d6#h?H{S8NpP3dZka8J6IM^jT*73x4V{E9 z)2P)+u;?1S8>MY3=F4FnUp7Ct<$l^nh3%)+$9bNgQCGa(?QTZgy-Lu{y_RyNL1)!* zIo3L~2q)rlu@HRYdfieW_9gt8HyapL?+tBeA3O05-zCBx-3k@9UzCCus&05oM+Mr1 zKxSUMMFN7v$Uw|)p?1Y*K%11-Q!aX?%FqKJ3DuWNrlr0#M&R|9W#KyFpQA)A<#{yo z?&9K!t~gO0u0(t$2OGG!$xq6E<_{+skM&kZEZ19Mf6fqA-d}0CIYngtgB=|m5U3U( zkDmzFeD5)}OMG`z2OeU6_KwrhuV7)>9VWuPRC#cxMzK+xu zb8O@HnsQGOozj;Rd98P*8eE<2=YA{lwlLNTNaM~>sh=Er=6!MWaT%R6#3&emyZ5m( zfRkGfK%DX4mVO9)6hh(9I%w93(Lm0l&one5q>Ba1=6gcgngJ#T<(MR41SYS zQ53LF);Pm}7HmIy;(tKlF)c_DDOklsdWY$de6&Wb8p zbS)Qf`xYUGDMRg|S4EhosSc#Q$sY)fwd|kx67`-EDk|99!!Q7n4#{5?3Yw(Qvx6j= z=L@#*+?cY_wHzj-%PoeLY{L2*Q8$L~k42U^ z8^hntmnMH|ZNEeF>`wydbhP#yPN=NRt@;g)wF(0MpJ6xWJqeRuqVvq9XFyUe%zZKs zZj{D?=)LwzF`(M=Gni%H`>o|3nL*Ib)1vI-da_`8LIeHC@Ep~NN=(^bha$yq?j=3* zuqvKu|7+mG|2ua*)K+XhUDP$<-8aDymV-9~%eWw?IUYU;?}rzp`V5{^e_s1}0}*&; z4d}AVANVt5PRJ6#`Ap6S>m-p$p-jmbnDwHOKE)}b@uxxcC+gV^%Z%qDzVnIR`{-rX zZbbY8WlbB@1!_C-6tf4FQ=Ssrj;Oz-gK#b%aAg0C;S-1{F?jjoyS5s;%)#T4yD29N z97R(8X#8gOP7)zPy?jsniR9OfGWZI-ztccjatV1bVPqcyjh>Lo5Nyc|OyMPyOLAoB zP|ET=0c;RN(?ZVyqVFE>qB-S~xaW;*7nbk_{uQE2{87)5r7Q}uL<$O2*wpgpjWOWr?Ynrj4W2ZpWUx7?Ia&+32f&VEPnEIB*1wkdG9aeP8Q?n8g{>=mb3AZucC6pQvpGu>*4P z(f+cUkglQj*3)U~iu@@YyFAck##NU{xgdhVqD-(p53@5fkjI{yspO33DjNgNCMtSp zXf$0zjS=)2jH9=Tk>>2uj`MuSHGF;nU*Rj zZugC7)Gb~l1*GYD1+C`K`6?tjnjF=axw|Rc&E2bT)V>iKb_~`=fr(`ldXUAnQ~#wk zrtOqfCl(ZFQbnS4%)_?M^J+-%^lLqe9VlI+!Of&t!#2l&6mNPT)4kC5JsZA~%4?hb zi5;U9&+5Gsf-?wqUbhH$Jqp0~=uzt2;nZaDqK|=Bbv6@WujIdD5m(v@2efIqvV(^u zqV-2!2A&SHb0Rl+II|bKiGkjiM9=96xqINut z)Rg8dI5@4OG6PDQY{8Y07*cUq8 z?ttzv%Xzq@=Z{9XVxO7lI==Qa(-h5D31=bQLbx0`Y zAb6Tl_dyrEMP&7HZN*t!3t%X|2-Hypm%-FnICsVK_+&kl;P=HMAnc-h)J{_~+K$a^ zTP!s3%s6jHx^oC8Yv$qg*KW;!N(5?0>hhH<94i-^)?9}9P&o@@(E#Z$tV>_@53@%` z`NrYHwP`ZV!LViuOnZ5k}d{3#FfEI6njw)Cf2jMF+&&X2w!e-Y3RRYMG zRG$$VG*d+Kb_EC_p%5G1nvbU%ArXiE_X^DTE2FzQD60EL)t;b)moNi`M^;E_x2l4K z0t?EX&)}c2qf!v-qj(Z-db$Rp6oX+(acurT=ba5&TCw2@BVxP1@hZ%2SY!$*H9B-m zeau1{`F&$f_2$(v$<;fhuUyQ$)|^jljL?ST!$*9Mk;lHUDd~rpWWw+dvr>C?I6QNX^h$ME{M2Y&l=hakdJ zg}PvbKiHpC2vAhGV!G%oKx4Q{w}qVhel&H#Fl_cp?3t%Qm(CMXj!S-?$jcEUS`5B0 z(k$&C!4pv7MTtGPd)(*lQ`DahlRCP8VBC!Gs?G1#sWjg3+W60!Q$J{bm>MxF*;_@0 zIIBB9kIayCC5aZ1j%qan4+43_0ED?=qkXw-)-dzoH0`j5K`um;X>*{GSBY9gG-JLe z?$Bhk?`q8UPj$F;y;H+&^JKi_UhkFSmj*N*yrFtZcFKs*?RYHXK(V)stL9@iKOb#N z50~C>*vuE>tL0R>n4m%l-khtasiaUUA#CFj5Cx#P$Pn{hGZ?TY+Tmu9KXgX{Z5j$b z$(kzZbv}cnAk{yICGReI4SzlGy4vqC0_18O!A#zD$J=<%9>|7a94V|=k-^3fNpLK!vxcv8z79z^ z`%l!-alrWnUrFWxv)q_o^ma07E5My87i*FZ;`aEXr4W+wILsK`HmV%7uruW5SkLzHYi(Q&oaIOVke}ot~32-KK);o<3{( zVkfh?T%L0YGUhGbHq(=*r%U_S*GZ3y8(cnBp+mi|XK$A`hJFkCAu3rz9XhSJCBlit z$Ngv`(S<{xe4XDw9vEu&i>2RMi@_@0=jxeF$AeAf*|dSf7xLalSci_WRB?Vrn~;Dw=S+RU0;&0jy4gvqfFD_g$vl}-HT(NO5I2JsPW!y(l@!?p*`Fjb!*5NXS?KxYy(S?l%M-}pRFztpG`HUFZ+O&X*@V8I^A5kTb8u< zVvv~_(#wY2ga9V#!Q3LWuT&eYeA-Uszr;t!k4LfXjg0%tyLW*&Q}=N*ng$1%^lvJb zt_sja?ezQ1OlY^&YIhF%{Pq&-fR;OHUeaKmW9VVjP5@PWcs8u=iG@y!nAb=lG0`U* zNr?!k+T~}Lp9?RQE7GR<8T96~gkvTW++{m1(d}e0rjoSVj&7At?aUH{dWNg9 zLcRL66dw3AjE7D!OlLF@c=%OtSoXVZ$#@l!Y8`+li5BrRD}KZX zbs6cd!lgFR1D(1lsTZjWE#{7bg>T7fuA6urU)fFjEbI&ti;`MOzYrmqhe$b-;}R@7 zIm*XyXf*RTM(GkI^#jCf_~M#%zAScy-5AkHNg>v)S?SNvqpCDqFq@A#Zq~~@Fz-_rpY*E2`N#*5;lfKOqLC6ec1^nX9JbUwT2)Q;tNa3@5^7sx2l zLNG>4Oh8WM?ZmHb#8pPKMLc5Ibm-P=CILE<1_;wN2ck@;D@uAd|pv{7N|MH=*zReX;f7lG*_ zqoEN-oq))VOut==O!jha4v861BM9+YNB!EaU;I&Z32^}$DHDDUsi+6<$1c0X*iN>M zQHd^X=p}u0P$C%|ofbD;DX(P6OOj67WCv~O3H8`f{}0TKf(**L=IbrBhv z!`|?NAg+^0V`rt}0pn-U3~jNk3X9A2)`_jXlsgRSMoM z$C~$p#CFId#Z8iB`kb{eo-gpc_SbnT{7(TANlGjgRrgMhI^A{OuUq)EAvx3M{Ws-+4JG;RVdX4l=Q3?k0+ zFjQ!BX3dDM7cqx${W0x>`G#;Rg>wr(1r#yY!!mG4HZcQio>yMJ;-^EpIZ`{<|Ip$c zje#`U)5#lK&$0xU{Oo05VN!1bvL;uz+rJ;{9_&W3 zN9zO$0^K5+=%bg_p9L4q^mO9i93xz%qAC4{b|0aM!L}IX7avqQ_yA>SZWc~Whoyd( ziXYaHxc%_1q}REImq}lOnrG(My$J)G)ny?xn3p8kE5&qST7HxMLg`dVTd^u36uZfv zYI+lJ`IyTJWfdR^ZEY6?P6a@f&HcPbdp!^{bxs;k3K#w!&VUL+Bq5&R^*2MFIDC#6 zPuNq~ul~4R^8%W0($oFt9i|2;$Dx(g%`ix$2CfYzt+l}#*90_k>Eo$aO<-w_-6kG> zvqaZUq}twI_Y3Ch$Bg4%iD*8FT2xnJA;)%sfGw{;Uq^m=%IKK4%WQ7DJwCQ~`a4AX zllWM@qQ>IN2fu_nc$0DOlDe;}X9Q!Yv!sSsR`Ghpy^bKyq@%X6DGrk0J^C0K6c78p z$9hSP0g0u~gr2_V(VHi9=d%<-mU)~VVs-O@*3evWY`mWM%aOg)Bh2@i42;5Vv>$08 zCu#62G8<9D$&VBH3pu%sJELEC8cXJ7Uyn#1q(I%b8{8l0x3yc0*QehXHB5SVn*DO1 zx_195my??KLbYgO^h)q&+BTeJGf=-27q=S0VtLAMOdHR zR`T4OLC}aWN^z+?_L8q@1>$Sj%hR~B-}_R5bzfai8B(zH`ugm`M$)%A5q5UwF??dJds3H1h&l2xGr+zJswaruW zvPACD+@zG8YNXT4YaL7nE`=BfBgLUX zI~mt86&7myS6YCmBpF@LyvSsAKXlKac zQ9{`T;+Hooru_vzb!evJn#{}*<;#3GMVyZ&W%q2q00NIfBICV&jpC>$T}sQ~Xq{_E zGfPfWw;tUuv^n6|R#^#%o`>$I` z!AXYL-JWV}fyBa})s~xlYwU2>GP4?ySd-%S!5r0;gyG8I=@$WQPrmT)p2hZ+EYGe- zO$ChbRl{q|w=Z<%ijaQE<6>+b?l8$Bo! zV7ty#^U&#YbD@F0C9BKfvRvOwV+#bj$IOi=Go3JZg+VvU;Ln?#dETyHt(P0bX^(3a z!jg{Gr-%S~r^e=WsEW;R8>@#|o}nSHP)yUYY~i7)@11yU_ngeJ4icg|L3r|NH+8@+e?th{FjJ>f1V9{Oz`aeW{~HlbO) zqt901iWT-Y?l0i{7oISB04f=rwrpSPSXzhjx(s7652};NzexL+Tl=Wjc|(%&k|3S` zg(m;Y8h^F@UH?zrebPWVIrMiS(agbwhW-4GRKfr?28DUBm4iLhwEzB(Z)&VrR0$O` zmDecCd5?mR3)WCZq8o8(HL0lG1e{)q{66f`v%k2dD#d&gQgVD;M{_uv)N_*8qLa1l zI4!b$SBpqCp~{Zx2Jzm8?rVi2dfQhB->(N7WG&G$8{zzV}ep80G>t>5m-kf<29pa;>UO z69(S9C2|a(%fq*V1O!^{2 ztS%^9^V=)gylFzejbISmT`*|Qho68`b2_-cgsvQbv<6AYHc#7Zct^C&+55ZRn_m0w zjsFu3(d8diq1=QV5rcAm#srI_j|zxHopA!;lyIuBn5rtSC(U+VtRv*tEWBOkZX zG6nJ?dVOC(3jr+#*}k%bqTe4RI?Sn9Q~Q51d_9E;h$gxbV7N;BnYCZpmHT_XlJ$uB zG9_sS)ON8D+WAP?H#+fCO}{tqkeG0WUJQvY-bDTu9--;L#i8p@6Y)t@A%y?zK&No;)m9dk4C2dq0|2yblH{ZEdMz1M2`CJP&Tn~e**Ez(rr&1DOe zzADe~@4TZd;LHIYE*a$*f!)IbFEq<1RtkLbe>Rp};eP}Cg+{DH9iyY5{l>L@^Bjz| ziu@GqP7KeLGP{hHmRr%;(*A$;*<7iV29tXvX6(`Wz9rF^xOq!|&bi!jq<>GmyCx*m`Pl?^_9QrSE+iu4&UFGF`xMJW{o1 zp7~9ybO}XxM>KQ@dziyZAe27N(@sSfOBmtrkz2ID54D$ zSAE(u|2%_+4QO27uJ(kJtE3ps-WJ#5qs=2Bm4Ql|*)b58_S|Uk2MxZyE0exh#rq-RgAUX{ z((n{`(3w!GsHh}vmmslrq?GL_s_`UK>H}5!p6La@Lh_dfK9BV1&y_@;HowA;3+`4D zk=h~gWwUBLMyFK^OO?{l8UPoM%6%y7 zOy~rNN-H(rTF7t5nXZAN0-Lh}(1xW-du(`FH8S;!-UyvX))@isHgwp1L)){?V2h(Q zqCuTn)OmC6r9VQxIt>mSuV%UiCYP#qA!4D=6L(E+mnOPPIJN5TM%;-dNp zxy@qZxZw7)-`CTQVp!4X(fG53i#LI%)Xc$}z$?#HfJ6Glq;t#cRMjcPFYB2=(oI+7 zNU+|hJ?-Mf&lNSkv#MlLw&W4@flt0f zlX280&jP(vh}YSkX!a+u8f3;DZ z0tOSP&&9wPwK&nU7C@P3nqb^FO10@=$K~p!>{`$@5#&Iy$9&nJ&&rH8LoxG}pjr!d zqSpNPf#uwnM#@t_Q}xM%*8yn_$pG)W?*Y0fbAA8JSMW{Wfo>PO#Ot6xHMc>_-n1+- z6W~d&K(%rD>psvLKZ>bH{^t#V4+JVcxL;)8Bye;#?0irBs?8dNBreU=Mi0&!6cZbA z*S6IFPqi4NRs|`{{qOh3(E!V!*jE@DKy`c0nKfFdEdZ8OD{FCa4|IoY0N6~FhpEs9=}4kfdD!~_7* z^mN*~Xz5lg27bh<9M1psUX1oTjV555l_KAYZp4!j!34SU>lBD$@_Y(Rf+ zZ*NZuP!E20dz@^20Qh^dX4j8@&S=P6LUx?<9i)0j>LE&QH0jK+206ln@%oaiS)hb|~ zWP=5$d?AXClMP8A26%*$60v$4{*S3U2}3%+wXbN7dI3@a(*#_Lkt*H z`~tBF6GjZKB2F?jHagfLF<)4%0ONA`H*9i~n*Lz|k3Jo!LGoI}|5m3+e)*#Sq>}&P5d1Ch3j|uT&&Zi0tD5YOk9=pZ41n>em$&LgDbnx-`J+cKk$3$dANca* zAsAfo$)PqM}~u4td1rN+@*FdAu5cfaePaEq0wDnR1m3x`)dOWT>QW+kQ>iYTzMk_ zJkY>>{51jqO>*$|{=Na=5Q##{p~<=5zOjrCRg7gmJ73GpwJhR5nkcXX@>i4BISI&@ zmjY2iFM)jj?EKsi{PO_p{XYQPy1;>S;9p}n99VO8kkf(N8%ZV9U;7^X@IUf~ ze_o>o08Rd=*X;|`Jedr}B`4x~BN)e{KX0qpYj6yx_|nmI;=DjynkMckJbE?y_o;(^ z>e~hvE|6Fh-Ewn*OYeIHnqM3tU_gDA4#Wbd(Y)4#f`2ap6Xum4BVwlqnExaIsEyln z2h~bsc-^!A_laV{kU|^kTOd?K&nwcYJ4kP9P2hTqQ5Kz!#5zU|Or$3o3E3&Xn^!l}yBTj)?FC2y0F3x7q)-I|Op5Hh{Y)VB0h|H^U^V zEhoBblm0)C0x;J-{i_EOd9a@qsAdNQz16FCQq*&_Z&e)tR|U3sgEq9Sg%-(0XN!Mc z$Hr&*s1|y(^yIU72?>c<&v(K9Sg3MH?SMa0h-W7SaZnG|&Tbbi)jI?F?7NAR`+1-%+nE&wf`jL{ z%FT3^?y}yVo}6?8^dV4vgf5^rd??O4iT>AS$_zEK#YQPR6H32O})+^{@&Rz;n;85$X zR_*_K%Yu?*|H>%dFvAmXcE!bs&n{U*ZNmnB?Qf*u)gFX{S*yc{yi=rh4AcSfz{IQl zE+QXK_X$pVtzaWvX{;h|qQ!Hcc_oIL|YT5^%p#}EVL*RL-Erg$j#$;(1&>?NUPUOzQ7p`gY0l%jipL_BD_U`5>9!Ul~ zhs{4R$vG4*v=$WKB_LR+{*L8(lX-#eitvE*3m!NXyym>qf8nq+he_*7FaPf);gh~M z&TpLQn{E|5Ro`KD~ zd2M$4IM{X?Eot&ylgetzcQZ0Retu4FuJ5|_>klVtWM^k{?_MVLWXVs7CriARHt`X0uG4u7dLg1jICGSb#VzHxf zo6ev2Z;oLMI;eKdsQJ$SlnIlcB~O^FbyfVyBZmh+e_H0;e#9U&ud$-lAF2eiN!|YIv5fpVrg(|Kv$ej_j4+fGfK6MRSr0{{l-kJ-ZXJv9Ybp zAt1MV{so0pTa(Pu2UEm_<>lo3zJ33Gcz(`ync|O&!Z~5z z{XWOMuRtgLKNl9yw?^yhuVY6}eEU{bvDNX__Vfvpe@O!4<0;$Dh{D@RoKttb{2vB1 z>jY=f7h5%(bLY+-QnVHq7k_y47}Yvz?L1 z6OFp?HED(tPdCWk`yA#Z_#on^1Xtk{uv=r^bbq<|8@Puz=*&a)oKTr+!<7e~FfGhm zVGm3|b!?U^gMhaL>P~GAdO9ohQ?tSGgxyn;4$Kh_(Of4RwRT^+#PaUkp1=Z=gWWBU!3$b=WFi5 zJAsN{jZb)ro}T-0&DMy-LV*P_U#3QeJKpV=bzOybXvnAU%y=he5GH9V7gc;xtx4|BI% z6W9E`c*=1Fb<>7O;HbpI2mMj(tRF4%wmmYKtb4ZE?#Q)mk@4k$K^OJbamGl;{FKoD zmE?W<$p(X>fDP1W^3>{a^aE0pxPa*&pY&((Ogr<6=j*KAGXR07tDnm{ Hr-UW|89A27 literal 0 HcmV?d00001 diff --git a/model-examples/reasoning-models/mini-r1-zero/tutorials-docs/imgs/grpo_2.png b/model-examples/reasoning-models/mini-r1-zero/tutorials-docs/imgs/grpo_2.png new file mode 100644 index 0000000000000000000000000000000000000000..4e2d59e033aa12179e829bdc8a4868c7d4d9f8fe GIT binary patch literal 114617 zcmce-4WqqALEk^^bX6a;P z75)FdK-(2N4gG(<{O`N99>Cs(|Mw67E+!DoUHJch5B+%M-a~2_^M5o%$&IG_KW+Zc zv*`a9n~6fx?S6SWZ-05{emQ==8&d-AWW^Fa30lzSTqt62lL^p;dpD)g8`tb~+o#lb)%y+0Xk6`q6l1_-+#=kzQ zJ!A1`Z&O|x+5sv6#e}cLaVP*ssl;#_x?=`1(NAK$jlTP@>X`8|N4)KO|?q%{>%Wmv( z5HyUYq1j2x=RkC8Mf|Q8|19@zfZ0nU@JPFzS^Q9GgP!Jo8k zQ1qrUL5yTjkl{^T+9{sI9r)&|4iHpe0KEP;R0-GmUm*p75LueljXwg-hfcoZ1~&oz z@Gd*ooV%Y^RdjJtK6`il&1uIrh5|P3r-Pq!%M>bcVnvRA8f)Yoh#ed{`SeqOj%Wi2 z8Dxx~j&!}_{MUEeUY@pI95tp<@Mx&~PHKnY0yu%_QJx=F)ni{B{B*n|G99~lYyJ5nw~rIM968fzZu=bEEj5C(s70qp|32Aj61sKM)NEw4W?PnX(@c>{_U-RQT<~LG zJ6F&}^etLi--cN<2h99Veujt5VjfTYYUYeduvIBo(h}Nv9_YONh{>E?1MPy(7WMwr zeuPxQ_Qjcs32ZLV#N=Uh$x9>XGO1m~aEI{8bx8F=5QSIe+<)t#Ee*aC1Fg94de9|m zX1v9OHMq%rZ;)y3dTy@PO}(yv(8ofI@5gjGEfmSZjtvh)u+LwXf}fWDfEKW)E}FZa zPP-&mr(9eqmzE9jW~^Xh|~!u6=$ zi|i+BuWFBDjpo=J9qR|mVdsTBeg=z5{Mh^Hx(uH&m{)EP?6C86J6V(y(KMB!Q1DK2 zO6gl5{BSUItMk8d64h_@xn8F5Za8uyl`>XcZTDV&2X-B~-K| zT^oULY<9Dj*v0WE z-vvby=8(!st4Vnz6oRv-Md!b_d5sIQfURp>END>=8R!5C^qmxbeJ6`U$*tS&?Rq$C zComm_vU>{be6IHU9PCVCnDHky;X<`{NO{{m?_nLuB)!CL_ zSigY4M|(*cXUIaMy}n`iE;G{T*Ed0|N<&zDk6eL=D((M1PJg)*CGQ)7kw_pFraD-J zUxIJFgO|j2Ud7pHRD}t4Y~zXtH$`5&z^Y%4Kt6+PHwUXLgPpgV?UMg=KycM=7B&j2 zZc85#3@bCG2QH195=u`yB;9aHAZEBSGa+Hq+v|Vqmwc4<_ZdL=Lv8pmKhAf;92R=1 zjmC|ujKZ3zMwJ{?I;Zxs2s2j?gt#94Dobs*l1-k2q3Hr^AB1~31X5++{T!)O^GvtY zWT73dc1x+aI`&NkOrSLMey&M;8blsrXi!Mpt}uDB^H;)0)Id^>@(#4inmu1#>hgL4 zD)A{EAJ0qx_5%b&jIYaePagy7R9qT=p%8et&FAPaw7~Y5C0FLet?wdUQQ$8wkf9Udv7z$4Y~I7wOFdlv5}|>ce{-7` zUA%6-U!N9+Ldb*`@Gm}<(EdoDASYwh+P1}sTHyWo$mK^R0N>O!LyUj@} zYH=oZFtCz|kD5<31{=W2IwlXKvQg@`dhIPH(48P@YVHXbFt!wE0@o27Gcr(Xzr6nG znK27UmVJt)u#;EG;zp*36;5RVpAJOmy|x|?YYXpqB6tkPw4{g^b=N$G*pF=ZMm!0? zh9|T+px`KvYjIi`snr%$E!1gK^#WRQ_V`%%EeH~hHzmC3l#7`6d{y>f>vj4Oi)YkV zAe@!rO7!_|_hHv(8L-pbdFmKQoac5UxEoQu=IJKbHX8$bs4@u?GKIwsPhc&sFdn2( z9lj>Jlvn|kNj@ER5`Jz%gQ(+GeotBf&76z1rxI?H@!F-%{FS^+h`VCsd_GQh#p}3* z6vdq#o*e{GR4VmB3@A^bH3vSNWVYj~(C|8pz>ndu%%u*i974=#p822&p5d$Ab25u~ z_MJvL-yG&u@K_u>7W>Sn2OwF*2mj=)39MbK)DYuk>%724y(Md=4?XAE*O3;9j}3h7 zKIfKP{(6$nkbzdg_E|JtEA6X@&8vR4?Otf+17DZVDN(i0pnivDff%l1<+w+8ZDcOg z#Lozg%2{9GehaQBCtok32Fi4G+zL+P7G-12erJ9_TJ z`oYcn%n!#>?Xv;58{Q`QtxbN-D{o=LExWRTzhF{_BoR~XkIjvvxxFAi`A4K_g%|J6 zV$c6Mz{-=AwbON781F~6@>NFax+{!7h`NzO*bFA|^d5{~ zH^B}m+rv0Lrp)r052*r!hebDM*W&1EBva`jToPdBbYC6m-vD6yqQ$r43%3AXDEl+e z{)%Mhw$|PwP*leUt5ZDs1m~V*4?3pXxxp)^DyG=MAF!I8_Q5^9c-tkw*7e)s(ar3v_p?zLU4QEvQbS7v zWz>!8_isa$>?+AhlRxrE9Zf$~UB!YC=n@At0tE}u`0PqMsVcH$e=PH)?;^=k(6h`c zOFokB*p`qwx)5o8Ft#x0N=ZeKw^t?5fQL00azU*$>T^DdehnsJd_Pb1_XY^WO>CLUs%{9dUuX zdf$#Bm`SMtAV&c*%$V{%j6~r}19plZ2W}h^3152p?|$>tf3!6ZLgEdGh+~;c*4uf# zghfCa6sXZhZ%K*7*O5es_oSxxHE6NgHBt7pWN%>;;xN|MpHBX@D*V`j1fKeBxj-)f zWq{^AlcYzfk1+2WTS!RSwOY*wN$$PaV|Oa#C_UU(caNw>#5+S128$8&=X?{w2?P6? zu_U+QLBbCd0CK}nsf53kIjeH9kg$<=BHVNHJ*<9LMKl|?2BxE-MOibP6qy*=9IE&x zE;dEXXKKg; z9t&K^?Z32P#^1OgRC_g6Tu<{ct*RI^TjmNoc}|vJ6uvO512aNa@vHBukD<&no&^Sr z(shllHJz_>Hp*TIjVQTzGwXcqHO0d?$5bWOC?ldB+y3xZuh0zZhUC?wgjUvFKnOd- z&+TOfh5n)&j&BncW0)B!kgVSg37$Rt7AX(4ZpF(1+B#r~W8XzlO~WqlxY|$Z#$_&t z|EO4?RxOL3*0AEY;{w7ejpp!odu|eW{D84;d9J3ijzB(47#+e)skQjAUSL+kTK%XE zr4yo}Lyk2$a-AX7V^je+%DGR2&&$Ir+80$9l6Eb*odD!HFfi$-p~V<=AlJ8~Dhg4R z(OThomNV#DEpFU=Bzr~;;`6c?=Zn!WV5C>1=kq|N9(J~^_ST6?tP%idt3w~&

Xt+T#df>$XElFfiKn&kU9j z%c_v9_B+3WkvHdpN5#v(mwDp zc>ulU4K_`z%TdH*XDwkO-yH0;>*qa0I*XjM3DI(#5La~7f&TBD-NQ%dSAMRTngh~! z!&Ep|2)hH$`b5pDSM@%{NA5!szM2opkBCvIc&N)L)P>jg-p3S0Zen7Mq?kAEgl&4U zKUG1UuOhR_Y2I6#cccYMQNR>wgjKG*uE+y3c)8yfRERNKI@++#>?t0SkQaXE*@L#F z=BCW_Iq&fzCRPf~pnB-T@O|QQl#u3Cb5|sIX`WwcqHiN!VM36GmldRCH%_ zT{?+T#5EFy2^tPv#gBkW|TCu2fsrCf-i zvnaKXKlFVQ>NG`&B>e?9OC(3}jbC-lAo@gf@A`Qq7|u2(TF1^F_}JJaWp*fZua3pp zLh%INc9E^Uoe^kq2)#LC&Ro(n@^BLuNwR+bfvd}uibf=Y@`xmGcWa7FP|Rxm5Iax- zneR6l_LU(nh3uqwOO)0vUyBMVeKJDWt3GIQ-G;L2?ZL|DkRr+_789@#!|N>Ag~>JfT6}8CQZ!3BM*l(!P%xRwI!cKR4FpqN z!Wz$U6r1Xc_3%i|9_oe!D_t@|BM$a5%~I2@2iy_2{`%)#ev+X- zv={y5w2xnLX5B_>|)hKJ&iVfK6Mv>cw4WH7+9d@NncJ3IqhKhkoCkY z>a&LPgU5b+pkw?mee+470EVi05ma@uWtVMfATJam-=#bzIyzN6r!YOBfYi76lAm{o z86 zazpgo`Jh|26#d#j(Dus(*P6y>ur{yWl3@blZL`7Vu}UJht)M|<-UbcT>v*cAOz|iq zm3&mGi z2q_UU<9BG|xdkhS%N;zb);5Q~cF!cz`_NphLOaZwk1g5L{_cAcuOqH3K)2g= zMeb+lm+Vi7hA^ib8M(#T0`Iz=3^@mHm<;^P=S>ME8K4C&NQjSY>4@d`LH=jJoJHy+ zv4T8V9hAAusuiw5wwpDPmwqNsO58?d3$UNSp#}X(0#oAUdw1exp;%T z3z!yYQ4Sd=7;8z=emoP;9vh*6+i9$#hB&-)plZtwJ%yfLhgS*;58=vZBicSGV)1mx zR;!KDcQi&oOk*Hycj6f(qmy9Ct8+RR>^pd<1r@eCZJ8x&mzousSY{Nx)W`ilp;vvuI^G!I||Ksf4k;Xa93HzS>WB)y5E!9<;6 zHW<5Jno!cifC7RJ8QH44A;}wN3ik%02ER+;A__?E_sFw_B=97LHzT^wHIt~HFa2BI8rq*#uF^R2OsGcaZ&Bv3?%}li5py~n`z1na= zpAK7sSuWD{Mqy`dVj2_TEjy28#oyoBV1_8xF^v|&zW!iIB{$4?`jy|}u94GL=26GV z5LJMt%g+`sWvS&NC#_@2So0%Hv`WmzW-yb%m9z8)A z+?##-{!N*6#ZoN5{;*!6H-+6{tOtd?uFCGF0-%hBcq-hQ%d9O8W;qgY*M^4-Ts8GI zA6Z51>xVx&;HUwzOHg#{Dz>u(+6@aFS&Q;}P}Y3}l${B8LM(v>MFt710o@aKzrp23 z)oEl?$#HJ)1Q=Qp34Ummy2z$EfDkbuR&arCxo?Whn;XV*6qt8>9Kl2$wrO4L?oU#I zM>#d-pO3y3Xdf-kMqP%Z=)y=cijd#tQLw?M9R+&EBvh_N@tEmse|SuBZtZ3IOT3*9X| zSz3^$I3gtNQywAw^}2dVj4ZD_3P`B0oZIR}&TPJ=36*6FDxpO7WhRe;T2Io8%^`iy zt!21|VS|7QFE+1UqM$`)F`4GM4J7Q?{M$rxw_fUcx}>X9v)+{FMO=vx%2|4dMe0&| zyFB+?>~#aU{7Lo0qT^zSr@qT~t45Qdw{CCzwl$XAmM7*|6ZoS!OdYMW-$7$;{e}tF zD@?954HEY@kR*YrT-`$PSle zfY3%zl^ZT;NG0+w93GCAF^@Krjf#Zq@9kPm(e@nowJ;ZcN`5u9$7Hn_Fm? z505m8o}yC5%=9&Qy-~si*+C=F$mU5+rYztq-F_KEWLUyJra{aTr9Ig;IN)lSGo2ao zdxB3>FG(|uEvnIsrDdl@zWMjyrUmbD1xm?be=BQCRPB7)JGS_Zo+JHb6kSjGS4P0v zUC&2E+^YEsnpd$wnW^^?@Q{8veNNrqg18Z7D`CN3?D5a)H7sPiegR^BfJ4x5%GlVm#fttzMw^ zD4X^eTYF*OD9H*`w^{Afo(u1Fp-F0O;-;qHRDFTym*&s?lv=??5gcCo^IT%GELC_o ze1~lz(ICK<=}VzAON*VL%-e?(`iF+(YxD8%@LbA-Ov#6swGt#STr$W=IRaBR#qO+=0dkC85Z%RTOl7Ozlrkwb%oLLV3nli}!d1QId*l<8kRoC#er8h<|8seWquWHEkI(5QY# zyS2uqprxty{HRC@NdWM930yy80KoKND_3 z$nNPaZw zlno0-fiF@?{Nb_KI-@@0FBC}U@oKTI#4YwkJi}W^80XE;%tNv2QAvPkr? zVHU5Mr3ga1T#>GfSk;v6FN(dH`<-!K=PP4j1PMzLQf0&p_JiCtY@)-5q$6_&0IvPk z7Vjy};`R1z9Y?=nJ9>);9&He6+Yn@uDJY3N$uf(HC9~q9Xk5`eAH=)(q2go{3NYz? zX~RCB%~F!cM-5y~r6}bMm9YB?G(Q*-7Y*UerV0>MlH93>SRqR03`dSmWCvy10_O|4 z9B9SE7h~O5dD&F)!5r^pWP9a~9ne7T_!NrxnRGG~ZdB;1rl#Sv4L{pY3wX#h@;qM40a>b~7+A=oyWl z)fw5yrujrPHIG(D*+rl{LX>nGDV%G5S+PT?I=D!OZifh`%RH!_J{#CbiUU)vbVrmE znUwOC2q+=~(6vJ%4%}&;p;>Q!?h0P^$F2cKJx{{(zb8?XeNYHB^iiLgb6Q8-{*FQb z@Oe_tkE*V-rw}IO#%5ohTef;pi*)D`XJn0P%CHm> zj0w6DqzGXtv!d_szT*pG3j(tV$$MiI9y#daDvm}fvwNh7$+JCF(mk?TNCEz$*g4L4G~>u}CkBIo?%L0f!!A zT{W6ng}{YaEJAx*L_P-j_+AgFcqjv4o1(V*t-~koDgDA=6DQ!qgmxXEu4qNq5G}(P za(^RDjH(?3#16ey_aHR2^vk*&wm4JCN_v&JqWx#U1zV>f1)&}84~FiJK(H-`Te&L+ zrET-FI2Msrodtk(Ulqecw#i7;^A6lWLgL9X9D_e;#bmGSj>p+7{uoYwA=k)I2=7Mw z1`p}0Z0VSD&de}hExe5BCci=)%uB9uLursw5R>a+P1E`1PmCV!uE(cPAA-nxUA^&< z6o=WIRd&fL4DJcRfL#{D*Rn&&1=co(2jakHfn!m>;`5>8w^o!RX}dm2VXF59S>{2{ z>Z7?pCUTgo2rBo~!HSK>ikl0xV3*8fiOjv@T;xj1A1luRHtOK^w!a^wB8>iQj&iL= zBW$-<9NWzEAuMFRBv{X8Dx=3AQ(1`2V){V_dAcA@$O`t8%BJ+3L(&0*lwE%K?|hIt zX~bqQecH3S{MG~6>!rD0r%pPrx&lWV z`mjR@G@c>gc5;F>Sjf$0iWIv6f>`n%#|l=iHou4z@&F`>$lg?z39DyXE_!&}*#i&T zB63G;x{nYT^PB++8CZpqtQxO-JwXTGY}u$mTc}u2o$I>q(5S#AFw-()vTaOxF=sx*{b7&H9k-TI9aW+NP9uH~>?H zB^$zKIY3L@elbnQc?b4MAK<*?rUI9mJr_nxqg>W{k=jYHl`;gS*Px^d#i8Ao*d}7Y z*{Am|-YsG3)ibt9GY!|bb|=OOeJI$|}(xo`m~JG%Q*aYxVOp!z`B z<5{Vi=(s}x5tT$W3g-I&h}_gfT@9m>{7_ycU*;8}F@u7|var}&O}xpvNP?Vbg$DSU zJsL^Q4)ER<6^pjgEBhdbb4VBK3j3%-8eUL5X>DJjzRtIVcBSM zSKihCaNn7`FAV$6B-^HiG(@`N?T6lZ=asQ16$fdAJGy&*SC%TwF@J~p>dg{z40rUE)R3$7`!yGe|Oc7@7e@*8R z4as@|;7%)14*C`UN{%l;dx(zvOqNVE{}Tf`bPOib&WdZk9XP+2kEV-+f7@7s%t2~7 zo%>0ki98hVlzp4vN=cLZ%@=jqu04r{5x>T-`YtnOwf2*yigP22&ObQ&>!UA>G`ECf zn$wV`DaqZvzI_~eUEdY`oK#6~kmS%;e5=;S9#B#RcvC0LCKO#i*IBraZT?l2*g*GF zsCwfYYvPtNTnXLR&>tC?Q|Vwss8rcb9z!W)b+F0x&qyJw1o=)uw&13t(kD~H7|iD! zfuW(wRgon>&{o5P{={6V-pu8qTL{h9@S#d5`?(713BhWJJ%$5*o2r4f$=6h;#JCv% zb;jqjhA?A0HZ8b)NH_zw>>X!sQ}DNL1@+Z4E-~0#OorbX1b95zu@!|&MX37e`V^cI zk1&7X%=411HS1llQ@{l=@!bBHq_7faeMj?E#I2+es#a!XuDVbg{o1Z$ zGcXp4?AX-rFwTzIJc)@33irZZ&yEL|2jK+qo8f?-CF@=R>YDVd0P+shn!G=OT56qj zp?Y3gnIv$3s4PV__CGN9FkZS@UG{#KAeo&thav|SgI-^v+(*pzb&Zn8k?m#@lNFmX ztl~3X1|cKjiqYt_*`QJu8*sTJpLkt_O@3uHE(NP&GIk-(O8e)c8A?*9Pb#qdj> zV=Tuzb4Tm^j;Ty~#;xcSlvvW*+I-+|k@pg#e}E335}ycl*oMG%YA%*M1#$tzgJeSd zRmwf`KyGcz?c(nRvZ5FqB*~>;Mj~l0O1>ZQ3aIE|^ZwWRGa3bO8AYaWVhBmVx`wvf z=^LC?whN=e{~M^5C9*;=w26@QnjzaU^Q-$iBJq_$uST60_ zg+C!YE#&_QY0CkVrN{H%XM1ONXMr0L?f72IG2NE;1*be4uQtxxJ)5Z{Iv-9l-6N{) zV8K3W*jrmvb5C^!FI9C4{~^{x!~cUw^X|lL&@8r0MtDjd1}#5zJgoGcpYXrDRGhuE zZB^dCnA~z+PcI4IuJKzxS3eBu-#R*V*Dqbs1kM!B*+802ZkzaDy3Z6o;~TpFIo{~5 zbblIpd@6j{YJ7I=*jl1m4{XkQOxXI5;jV@Ki+dXt$A#RFWK6z&TtZ)>%)6G2lZDa= zn4AZ}aZ1s7-F98R6aBqiYsk*7cR^AP*g>w1&f~II2%Bo_Z|?0Cjh=zyjB{h1m*2&A zx9fV9r=E=;3+N_Bl22FjYyH?Ypx(t_USeX`zqKE#PA}|GC2GE+XVB6IB}CFu(=jI$ zl{*_A=jusXIYt|Z41x+4O>rCtPg)fyT?hPvYh2Eq&o*;h|B%^JJxo>y&F9=N@U_pY zi&s4ya0NGtJ{$yR)k1HVrXT&hXAsdn^I&^zCJkE-HC7nGmqmGA){paP?cSZc8yk8r z{Nhld4$*vg_zKU+{V`%=?Bq(A?HP8^ebWubez_YwBQo)u&0MS=4`xH8!7Isv(?DB6 zfx~Vg4*KcO32_0%%^xYuj(Z}M{;{Vw?N3$J4o)C&ncld2A11R0I<=3fzu&++#I742 zldkINK6X?I(#hdE>9t`vlY4VjXet}n*JH&6hj-T=@Ui03ap5_0I4?|cWgzzt?S68i z{*PYXSOmf0-&@hE<&-#DF3Ti@(3Sbzn}%mm*(TGg+HPOOzboG-@nlq<-af(lL<0jK ziJMyn^*t_~J=)GZtJgLTByJDYMGlK^7Jot4nZ>p-H~z6zo?S5(?JL7#e|VnmCSaZy zz5vZt-*b_rR`r+ml}V??+v~ZsE{|yvXE=Gj6sj)ir+v^x&f zoZDy>&gc}@9Gi_^{6}P|=g==WMtY>tseYUIPud&TcW=1|91O)({@zF@%|MsD>)lWn zJ^g)LR08fV%l>WT9_QGmksG8gd6B=7H-jkd8)d+G6-f(pGU(0<#p$oO%Ui(MRiodL1#Co6P!b=6b;qNCnroA}$uOh&hx_ndU zW@NW*{G`p(`zSb|!?2aNMrbqk=N@6s=#W|b2}8Gv>jXZWf8L~l#@fRLOyv?;zlPN! z0=KmXm6JP2!ZI>HYj*m}Ol<`jJb`!Mx8Z*^tI%Awi1^e7r}VEcI}IfUl|wbgoXyx` zl0NiR#;1RXfCHdMPn}nDTGB6<&X=8&F^x~HJy7b~mSGWzA-kdVFt+gF2wmI(IZn5P zv!6w=M>&!STnCGSVcBc$)8cCe%mBpKlqYbm_o>qvzHF;1#NXL*!^jml;~j5Rw|Eji z8elR|^n~~}pWEkWI27Dqgak~7^u}*w5kbB^-%HGiZrx9PcmLMMW8mwitV;+O5Q_Vc z%@c8_ZA_^jpqmwt7XE5%l721pC*9mn&hTP3E~p_`f?n)BTP@wppOSZ-B}i4IDPBxk zrr-_#^UhoSVO>IAI|JuVPp&MYZ28l_v;8F9cY{xy3Y0?9xM9OQyE8fxmWo34fVabW zw;c!9d`f#Vw7!iMMpcrWBfN`sn@Ai_-ju=j>bQaQ{mfFXkA)cd`9l z`>6&uXc5}--9DxGg7ex24z+SzY?~C1a1izoH^q5>LWJ#RpRa4(=&~;DB=_>dNTi~A zz0Bkax9p9iUv-99DC2>SW|X{-@`LVQuXB*z64Q3}_AIy13QOgZ=I@0^|=8JiqandWTQ=8U}n}Bv}+} zDP%(_oA#Y@_wyebayW{;a+BdlSP_PoauCPPisr$Qc7G%0(fZ7<*%z4{f7yWIlI5`X zaF({&f8z|c@LaH>CzS`gDSu{PSZrScCi$;y59arMGup{R>M96EZOY|I{%L4}@rgO;+)?rX=lQwo>9WyL z768Nli22ve9sFo@gsUC9v7(&5Rj_XMk!KapW^z8ER68T^59-$2tBJuoTYZ=7uooJ`U%BxymSreP z%KEQoc1Z#NaRJt`&sBU+CkhKZx#l-a({mCj=%820jj(Vhz9f}%Z0^;SI`e0rPu{H* zMx|?qXQJ1!lgjjn^X$C~|0+2&#_3^||EY;ImCoKIktDRj+0#|=c_zdirKA=!q zF|++E;JFg{{MLcLKXr}JC(#YD`^Qtl571{cpQz_Q=(D#6XamIT{4Idj8e*=6lfS@W znFVEO48m*>-syB}m&TXphg&74(7$H!Klz>$^Rlo~8C$f5*|Hp)07#&CJpsix$68@% zQ}O;w%q_}$k}9qJ%AsFVKBu`lhiUSOtd47PHKrPlCivUCH1bG|-qY4MVF1etM5ZEW zf_(h_n9^UlNXoi7^}#EpuDVPl*gG<5vCECnDP4N{c+97^JpxkmLuze^tN;~@y}Nsv zMlR~Kgf6jE+Nldn4ra-uUt`bU#V}pZTUL>!jM_p=61+$HFEK?G!O9qi)q{x$0NGob zfW6)g)(y<3+Rj@M>(nZ%c;r|-VWoCL@TAP}0!m5xYUigifAVA%zd!KApX5o`*3%(= z@R4sp4k|GQ?^R+DB#!xMO+Oc_{$uO|l-)Q9hqqmyF0`a-?32NPx>CEIkl5ggV@|qu zTT+|RR)BhHP%7}OgYFA?VyAE4MD05y-i9tj3lB4c5^ginS^uiANy)W0fC@TAW2KV| ze%&^5>k_(q=SO%ZO{4sDjyOsS-zRHY6;iGy2If$*f~{gF)(72p{__1`4e?e7kw8ZYwpBL8>)i71ibwDKD{lBTcMOTp`@P}l^LI7G2$!wZt~<9t{7tF z^)~+cLfyPvl4xvQ$iPNFhu>q!IX9+y0=V<*Kla*WsTfJ5>%@}l*w^?W9AY7!ACw`j z)SVpbD`;$!Bia*eWiQ@5$yroVdwuq=^y`UlcK#$JQ7F)9-x>p6xAo-LZ#cZ^pZuIz zMR_#Np8m#=jGj+v(O73+lB1 zSGi_X5#8B0ul@BtKIi9*rt3}NnKRiBy~4hyP88f7qhm(*(fQNFpBM`M?cNx+&Sg1C z>np9>*4N3Og%cOBNB`iJMA9L`fiY;UEIViArgd${9S0tpSTK2hOnWACoH36{#u1_h zd{SyqDwt^df-L=fVf?`Ew^2V)yqH0Tn3ciZVNq)AvJiyww|~QR!}jGO^6b$dyW9VF z5q0BYt1WH+RsF!5z)|Wtd&EFZWIspOK#vYpigm5WynP_5hqaVa4B;I|G=^rMcAW`~#RJ5PW5j{ha3}IH=fO6*c6%#2=-= z*;{p9r<3M`I+c`jI2kab$D1b^70$XFfueh7K-PJ~@~g}_oJna!7(v4)PSw_YlX;ED zJFsvEPaory45p7oR}LyxISO=|qF~4i^vB;5=dkVPE7rFx`r--FHL`v)4G`9 zwPcqvR0wg=&Gx^Fy<9|6b-Yq9JRTKz4px74zZED}GVFs~Qox~%!KxVOhIAz;IX%6vUmBsR`@BpBZIz-8}#A^VPZKdz(*I4--uBwZU=bYe3@y&((P-%J){Yu$DN3S3dNgzOxgga{J zn<43Iv6!v?W1I685EURp1e54kd}8lI(1kZe%>;eRSY`#?!2LveV&DtN{u+V9vB%s2 z9>0;szAcms@fvzH+1pa{xvD?N?6&tzGPs!uW9}$2i~*ex`Un~q zq08uA5@cfC)qPY1u$AOEpylL$o8=Lq54-_ERN8(b#R!Syj?2%uf{M)R|0Jpx6?H_M zjTZW!Iw!p&TIC`mqAtqY$eVk6Gch6DRU9qUySWcrmSwbfL_Oz4^Dx?bl*aedJckyB z-OOcV8wg~oGT!YjV{P%BTvkXS8Y_Y1 z$y;o67aH(a7J@FFmx0zhOEAJG$!u9|^Hp^=BKB>V51EF0e&a_IqiCka23?bxF<$}P zz0=JjFD9XiZN(Tzj( z&Y;hmO!-yu0Ww<+^i3&ozX4NARtGU0!2ngOBDz~;b+Dj$yhN?+LS8@H3Af0_HqgHO zEdz`F9N?+c9~kEzMEDAH4X08ce=I>%0|uG5U+XvLg#@I6qozR??|G%04xOjJ!9naJ z2j$eK*CVQT*#UlnTU2>Nc^~Q-ST1UN-USh5?kGL`XMHJB8Nq?IEr!b=KL0hmK+2IG zKJ?E*+4}Nj5v)9H2_nEvz$HaC#{7ZEu0@Z-Z_lV5o^B@|20w0i7(x)~rF+TPqej4* zoO<)JnB(a#iNeTefRD`SU){o{Yc%an(MZ#y`KUGDeM>K*#4yGX?WKuY4nxsRSCtO| zFd@95`ctJxF0c1s&Mo~iv*?OM*3e57x}a`uMhS;zD^ZdXBYD9y+K@Mxnm8Z1PtU#s z`P$TFcpq^F-U#a*+N65FgVEkDQN@Z(8}2Jf1uT zAT^Osuiy21pTOFW?*p8pEe42SA$AHWgZA%`yDZ<-+sg2+Q<9zj+~o`dDzI!>+>_0= z;l-$aQ+O@r!$tn3Xj@1@x?3oHGpCV*-#vS!UBvwzU&Z>UYrTVojE*oKZV5wVsD_0q zqtuK&%Q1;Xj#fpC!02N#c>cnX!!5PN1z zOE#bqpq5y zDwH|j-WBSmEQ6|fq&H!89IlG2G1SnzgK~5Pa(ESDt$TcX*E~q6dFQ8Vj%dVVOx6fH zLi**F7(?QJM1yRtpVRnSWJsnknZsj4+4G%MT!r!V(zqFDSef^I79696orn7G=3gZN zDt6n8+G#wtPN{uf(67%;9OMh-1R^7~Q9NFiMT%DlEG{QRywk=62?4H%b7-}xp1zi1 zEsrIy>#Rk1UkDndeW1*Yr-4oRx)D81x)zS3Je}m3&kMBWAFoJFwb)O-N!K$~iNVd; zNR70Os2#lB`(Her1zQ_!8>Wl9dvW(-#oe_y#e%%JyE}oR!M%7XRv@@*klED|&lXiL=x#Grf+=I*TnAk`&WM zb50eyybLj?DQQ-7Y9tgKFe`<%g|wLxb@5WP3{&H4a1Jm$TRbe9w$o!w5N? zA|mqf@)2bH3}5ZkN}TXJ+q41`pOsR{C;|NI9}-6jg38F9EwCx5!qh?P^>Ap3Ed6mK zqW+FzDb@;Yt`kmx0<}YlEh+sYh4Z~0>wHp=0?PG7XXAXUBcgOsAynL_;M0|>$J`aJ z@THo-dNb>MS>Wi$V3K6oj@BtES2YetH^$YExH(TIPE6I;^C3dqccabHx^$*fp z2JRtbVXMp&^0C|^&SY?m(^9BblT1?N05~S!gdV20F*U7YJN&OlR$2fs0Vt9z`BKN#tgk=ogxz9?9mAQPq*EmaCIpI zG@5Q9%gUZc+7&i$;o9CIBg9qw;X+?U$zP?w`sX;am$vI?V%kLIHX8UsbYLphG5=La z#x9wcGk6xvGDBn7Y@cXC=))T;ohNEo)q>SXr$X{+pagU3`}+Q$`%R>eG_rn$z*<3Q*KcIt?$i-pSgVs^d5$m-8vl1WhF*TjRm;c9_{Z+ z&d`!X%L6PbkHHMH2Op-Bz);Dp?Q*eQR*a=TNdjo8o9L3dgY7krGl-mt3$R60M-8^6 zaT=ynYx?KU55)v4GE=V>x!2jn7X?Q6oc|pesD!b9%R1sPUf?U@gjWuUH~eAr-uahs z!oAtM{!w37n`asy=nI_f?4$bCI`SOQH2{m+&-FUgWB`4o2>sL(OE@;vGLK{}s-Wbg z$7c`WUp*Q8*B`YBB!vi+AdKHYKF1)o1O4HHfnuzxymCRrgc?(bj<@1ux|1Syw;&n_ z2QMURhqLdnDMjQJR`RrtfVXc3$b6QBxFe7<}3Zj7Y;48+*@l5r~ z71~yFHtQ>aL_K4GUa)RVup55W>GS;Oyh*|tK1J7c5+4xW@-4|s^fMZG6)`ZAL(Z(k z8SD8{cjNa5Bv|rUwKh78nGV4mFiJ=rH5ISHWDih{6lBWeZ==hL#NW1YXr$b9TKgc& z5kYT|Qp{w>N(aCHQg9RemctQm@et7;7vGhi4wR*36In$+e=bU{~>(`digCk%QX;Q2N^d{`Vp*hLstY zzlGC^RzAW-7!x~V9%GMx^p+Sq3&LD0)RKyqG#XaV0O$n?`iQoE{2l83U(OVgF3L}; z?$kUC*n}wjq9z}ED$Xu}`>h?&lbb^n+M7&n_44uto3NFiXL^`7K z-jhp{J|&aU;R0TQK0Os)HhDgF{-h*9LiM8S;b)gNeKwwv{nt3KlMgsltcipzzD~5qZ5O3n}Irry*Ks^+;gUta8QXeH?yXYz(8qLSfCKW~gv z_p<{PF1c~y^}cqVa{2lmkU}OVSKjIaXX5BKwB*ry`vJ*H3=VLEbCdBLpdoL= zMbCVnt={L$&x-MyAB(bp#41mf`oOTt=E{tw(pvWb9catO8;s6j{mS3QJ7wekV$_+L zqoE8z53k(De<%H8=m6ca81##2me z3`Zqx;MIsM(B`*H3qB_kj~l!WA=D!WXsJ_pfCiU!n)2R3^9kx`Aw&#p(_RutS_xW5 z=O*-)cmtKx4Q)0|Wt@8jEyeO*<_OP8(QQZrou!soN{4{#keDZ-Xo?@SrOR{*X;P#( zV}2|}6-5^})4|PZ*7&KHpN$aRzdEkih7F750+s7JdmwtE@!=$_4-T6I03IRO$OvUc zSFW3}8K3adZfj@-qFN@OfL6Z78Aeo1sB4LEf+XqXW|<>klims4?|@Ry$saL|-<*lg zXagw-vDa~}oRCn_m?rRb2j;=;9>GVyglSGe%S~9EY!}Wg|J527Xvj2N+!wPsUCXxm zU_8b-t-Su)7rV;~)j$Ko^TFBKGGX8Ynmdu0+ANe1pk!QItU(M0lP4c?zF470@31;@ zG1dH~#mFNf12pnN3xAEUh|B#y$W^95R9T2#%du;k4o&VE;0K*KOPfmJ zvQ<-(q)gAm&bHQLbOOoFCba~Jj3Yq5@W*djXxE1)^;M2eO0D1KT(tj>X9VqyETif?sJB{*hKq-f$+XN=j zBPl_UW%Y%5>eOFM@M@$Z8Bq7Hq15F#KKQUJ9Gwh zF}Y~im8S>iyv3R%m1&jNS@aH3%8-w(?a}cr>+Cy$#*yP%bnHh%}?3Ffsq74B=K81&g6q=ijac4;{gExaUNOjN(lL19#U;sN&2)LBTq zIyG!COTdp&=N)uUWmsk#!Fz8INSynAh`F* z9awen=T^3{58}8c zLI~JJXv(S2LGG%chHhYf7;g2x65e^H6ys}h4r(~}V_6$-h12|Uti1~)LGorT&-F2b z?Es7r!P@$Pl<7T9Ds9cSIF{8~a2jwt4)zCTS}Cgp+Qyg(qk6hUu*ly7E@=_Bp1TBA z)xzY>4$iQmFYCiqh#s!r3{n`=i%6E&Q*|6veeosrRGvps^9n*s8gX-*N zee&0tuPbhP;<<8S>mNs^kO>-xDlF|Y`rXa$_}lLjV6;+tg&}2Jzc-?rolnUIwJrUW z3_RZ$BlkVwJGv!))EtjbIi`7ZQ#8AFMxB$(yr6XIV#2N_WRwPQxaWN4?SG5PInJAE zJPf>=H69f+;E81G&&)~+mG8TM+ASqbIYdjivGC>+?`Oiftig`)IEV$*S4BUi=rf3( zh-04N&L0;VN~=`7FA|!}yAeTPx$jqm?=~{a-DEqaX_ed-z?VOiuwFG;scacB7Rlm1gMm&xv&iIV4h?=1sJ#xJoEjIt_kP-S z1a6xaEDU-~HF_h8DoC3-!Z`TLp7+cUkmyPe>;QS``>EHr_@wqr12wiZ z?@SuSH$?pv14XIWr7kY9coJ>(9A18fM38&kv2Q;OW5j9n4+OESB9g)uWu`}Z54@7X z5X=O$_KdK!EzlX>PWjK$e4eptIZEF|km1)aG%l0er!VtMUZRrE1?!E4Hz&=|8f1Ma zQ>Wo91Zd5t0Z%hM{molI?VaY+d-_6c3-y#pr1paAYFaDV_}eEjX4v(${@ZZwUBB%P zVw~OSwOO|MVw^uGvvvr0{Gui&k)WaC#UK&Anwf;!NGGWifC3(@0TQA zqwb5*RhW?GAceicjY8e=p>#2|$k_t;vaLCX=j@>zcGx2T(Z@4My;SWzh*h=@_ufM{ zX{a1cPNV8*nmJ4F!J2IZs-b?2VzWm~DbSJ>g5{XGn$h~J+sDdy!Jl%PQoD;=upvT! z$wzcPVKbu^>JpQRq_rP3rN^KYgzM{_6k!C@Zpkozeo@P^YVC4h*NLUcuE^bgN2ZEP zk>m9HHyuO|>ocDztD$o+5K;e9HN*GG7hl9sQ8Yc~lUv5l(Z>xXJq#%^Ly5s%kkp^-gq2%xke@Fe2ZHVzlB5u8& zL$j4VD~Bf{*JMUbFiIKU4Hc6`v=o8YE?ep|MD2^3EBkjIwv`{;s9*k2;HRIGK2mhE%i3f|^nclQVccnF7BRf(SZxMEO7kcO&ZV_egmZ z+&P$I$#C)gmuAkmbru1}T`ij0hepn%yc21gIV&w$@Qf!1GPL*P`FXC5bU=3KTrg(A zqiI*^G_=keCQk4pM$w>w_>mG!ZyBKTf^y!`+?@HT1%he?7)U>g7p|~hd-V3OM@#86 zfkihqJr%k_WSQWqgG$7pl+wTgVIK00wlkB07pS&|B7Fc!Q)U=KeU%of^Akv2BOl~ToX=G z3mWZa)A#D3(1c=v;fLV}FVO8)IK^<9T@>*OmewXx=)-zprzl11UI*vYJf&Bkuooum z81e`Pf*nLl8o3n+n_2O%2@fM%Q~tLYrZTZbCI)nY4m^#zFDk1Cd^RchK5doqX*BcE zrchWT$6dTqwB>0cH1x;X)ia5KBVQ?^x0Q`ir9W?S6qU_h51I| zC*6_4_)6kTrx@EK?sx@uikgS|&iwTyYQqcD zy)z)qW-}3C9;PYI!Mxy_tv!v6a_2b{L>AR>Z z%ul*$oDzBzEiDeBe9Ss+&djb&MJEs5=I%($ubV}GB$f4|J@!lMmDV_n-k~rg>{G+I zN{!hIR8CYXN~ujxO>?{zkihQn$B_{5DM&y$3|#s6GVuucixMl@6e`Cn&N z?SB5_Rxkc9u$DW0+hAHj>{3HLo_JWns%_V|?E_q|zVd6!v60E5!KgqkVedTk@4Za} zepq_;N{^(zK80$3>F+@Y3g!OTy_%Zoee8{LJ946oHrEM-M`UN7H6B_4a#u{p38U)w z`l6#WFX2mk1EXa#v_N8KJS8C)pp@IBFVu32q^T-bP0wUr#<*m<$;=xoQ0#N-Rfm_w+3&lr3% zK5tLWc65_(r^B_-ovrhw+{loWYM3?m{m1Curj@jqk;k~lM zqxCJQpWC1L9ZF_kNY*s!-fJ~8n+@-tNQc|DExk=)#79iQ8+2+)R+;AsqBRd=6%-qQ z`$mZ5pH3BV=np1gDNB0X5`G-I=S8#+Etr;)gO1>fY{@Q>2YyV{W+XUd?y#$p;!VcG zz=~3d22MmZ8=i&za=<7zG!`lYidoy+L8SN$+|kov2=047U*hy7UcAI~N5m5Ff7QoA zkFKHw0*K0(jCSxwIUUs{vmau=wot%!W&E_vR_9`gA2w!D5(7i;aMq(cBz;4Vrtas+ zFvvGxFac|Hxa?Q?LocaolgH9eDW2CO?&|RC3WX1X;xQYbA5zp<3H5`B2Z9kh-8-ya z=~#uPFr5*MRmgyko!({xFELQ z;^l|UW?;%NRB*`ZOPlD6XZc-6+RJ5p*8}Q;s zr>vaZMW3+mS3CQ0R>zPP^gTq70PxH~M7hl2>4S_bWdkL)dA5R%SLyT4PKXc+%aDkA zjX%(*28Ct+;OgJ>_!A2&eAj;^V(ER7FqLJRVp+Ssv}_DhHMU+S&?r5P)_Y!~bcg=~ z%`=c;fMUn68C{$%V6M!?Z0A>|wlGgmg1gc=o8;Zj`j9J(~Hq!9U!FYB5u-^?zh2Ha~tU3_WT zCQ(-1Hz}`C#!eADEoMGgEN(Cq$ULhG-t;9yQTA!ct&M0ounf}2*n1CfW?|8vP;x8F zm~Y-BRAUo&*y^x{o4_Sqp|$NnirUA-4E|?jI-?&4;weZEH=wo?9gr4cnk*=N^;1H% z`67UEfXYc*sIZidGA|_JlsKYe+#<-8tOdch?lI`J zD!;6*&T=mtN-<;Q)}#nLt+FnY#eKIu>Z4$e3ZzvyIurd}BYMcCSle8TH1H=FYMYr8 zV;HVjHmYlW(pC7u@P`Cqj&BE;leMP8`r|zJv3LBIiEFTa>yYR4Qt-VhwTnhgueYeHB`1rBvbP;Dd zKRVF@CsSscwk`Ct$(LxG?w0;Gaz%+Du}5*&>K`y7{3z}WtAe%=RzqF%(@XYsMY zXFl4Ko>Y+!-{!sTCnpy66^|;;PVuQG_5Uu5YzH49pb7rgb{pTS7ncU6NE2Ttn4!Q( zx$AW;(8t*14!FKA`M0$hi^8vwPS#ONp<0YVW0Gw&4~O=cWX(GQCeB_CdJ|~d08css zei@Ey?}Z2LJ79u83;r}=Ppcvy$s;cY45+KjFBW6{SKE5B;M{xHvfdi%_BqgDFwY!n zey4x3ar$)J>fL*=IH}-@4q508xglR~2{f%*yW8oVj8m~gM209Cr)F8e z-s*?nX%e^DXM;(#FOWsNzr%Rz;e4U{4$WH?5xM>&%ilFh9cHX&KI^yVP>v9<=ef!Jo1Ftgdx zww{|B*ZjkQ$SNcW3&)7myWW2<>u&^h!AA9_WsHQJL#K0R0_W~P-l9y%o6WGp z*&Q`%(Pu1u41d9V(~svN3Cltkm%3{)p<5Vc2RBG-ysm=+XN}Vm>u==l?X?2j!`1BC zyt-&jAk1(5v(>%1@AK1}wY)d@Xsp+iyY7Wv&&@8{hpW11@#@YA>ksouUk0IKkQINJrn$?UTt9w* zKBnu3P80r{R5}3u4TE>kaK;V!_w@Tt6wc&L zdj7x6LFWQYglA#){D4E^LXH!B{j23aRMEV_tE6kLH~G4;_2k} znhEK*76eBk8{Gzsn55r*q!)lpaH)Vb<~ARf7SEOfZvE{_FHkG4v5p1p>6)cKXT3kTsYkKK`^=%@Dzm=xlitB zdGe4~xbO~6BBulPDzBTs1oCH%>yUWU>u)_ZI6<2TKwJF>gQxfeh(a)_pNLJT0JqGz%_^qJUjKk%~?w$vuc zr-SzR0!_JfjO$-+*L>e7X~D`(^z8$u@C*U)1GbU$n`!xtry$ z{wVt4Ir|HQ03^}#E>B+whN)Z4ri%H1+2u-7MkLMB9r@e$z%2BjZ1w~A2EU(LiZ_j= zx^T_y=1~dt%RwB>JCwap!B>;J z0EV^qhttaT0gjbZ7vD~as&TkVzm;**mTE}G;0SDN(`{>lEI@$*MYC_b`|Zydyj+TE z>%nOl9V`{-VECU;NFzGf6c*0;z~ZXoD}h)g1`&yu#Sw2R%;#A{R5AK#gP~n*rj6lx zlxiXId@=CE8XPLQ%c^m>L^G=LSy7D)$XhAcbD0pFqlK|6prXyf6Se;0eL7iLUa(q= zrkNBcy@~@j^&Wr*>>B>&fw|Yqi)=$cvE=(N|N7=OtLM7(<*yJj?=9Dw`+~x?p7{%r?`MH|5-n((MDNMZP|` z{Spc+;HAh-+6nfXP}_gYee?v|7$lTCHm!XO*euvilI&UYUNU)-A68i)&ea+vZBs-` z+;0ijD5&C)@4Q)vLS7X|gCcVUop}}>59^A85;j!~i*Ssne1Ex%pUy{s!x) zI5vhW_x-gk@V6<<;eC=SxM99?7{|T93EAJ2*e07k@GsdNDBeU*DujvaUmhchovvD8 z<4D23e(Hv+f(YA$gu>oSdJ4}89&wFSCbuS->Jc$|iUWBNQE6-pDsgS(a5hLxHLl(&DE?+Z+$u68h54PZyQ zJV&)mU{h~~o|lE*fAVgD*&Yk~C5IZF?62j66huBhJya0g_xn5ZgEwq-N@$B%P1IVq zP`u=+&+BY*Yp1CDlXlkHzF9glnJPwbAS_UK8YnKUDLu=g=9IsU*68~%9ujFEKYwgGoNNro;ik`dtXTeTEN)&|1}_FX3Zo~w)}%SWddLTg#v>rwY9opFOz$bl7S zUk2E#G$dG4MUjPdIK%G%VKuwdZQ?gDI!xy_3`MNyoNv#zf|TIy6K_9Zc8BZ47#>lc z+XX5eu0niK0-lI{uJ1XliV4QGl<&3suYjDloYjlB{r+1V&Gdj`&1cQ|NkCY=Qnzi7 zO}|~`H}O=g@B79U%nwp#=vy62-7e>(q!`6;hHhexnNzSvbkpghr)hqlju~*5;Nn$sW?-uYBBlE0t3FeY6%!h60vPk+{Z@_^Up= zCc%fj=uE295$+l7LMUliz4UWFjA$f&@+2a<8e6fdl;B zL=W9oZR%V?*mRmnKepAH?Qr>f^(Tm@+)KwxEPE?-H-HPf)5X!)a{e zang`6qQKln6rHsDDEI(-$!_*bS!{~Gg5}sk`9OKEV)oe${`*gKAGoI;r|{BQ0(a=C zS6XQUi~<=eFW^>&37{AW$0B8*PzsABXo2+rOYnqHy&g`$OvK~-=G! z*z|vj48_&6WjYSjY=}p8NB##g)a<8aX*Uq}t4Wu~T=k;qi`4HlFG;SK=nvjlq2Z^S zb+!Z_Zh@EqA*<)EA}NJBHxw@`B0o71iLH*ck^*#VmoDWF24Be~d)@_G2aSXymN-5y zYXo!e-Q!VT&E7Ll%@cW{eQijZRlH`kqYjKDq1gZZPQ+^BfLTQL_jc05C96fKF{;}= zmLGL!aeYf#eOO*9Nk9T9T7QMT=nhfI3r(V^^8j&4i5uo5ZrEfzwXGdM@iWCAxSy4T zCYYH64t&DdQ1jNeg~T_73gl{OH=8dad=H{?oI03zYYT&L0%&|<#28ox10TnG{Fy^D zUiSmrk*9|hxou9I@yndrY}4ppX!oy~De;mmy@tkDhsrNFe7?*4(KZ7pa@k@&kBm7C zz>bF|#$#xTunDNtbJaf5Y;n^w(LD%>1;Im(f zH0$RDo3QziZtXR9MNy=CeB*w%z3=)5zr{`T>v-+|4%n^Mg>0QO5gHy835ie|<~aED zDb!U`uC@qE3nZj*%P%GF+BK%H`Ad=$1cd@+%{kB%_|wx7kPtuwBE@7yimNJaaHptv z0R$BATqE2rc<+!Q{B|tc3?tTd+e+vSXf2o6AiN>SrWJCr%oy)M2;%1P9*n3wzQTrSJmvJqDf0Z~-<9QhLncEW_wFexg)lZ&1f^b2#KJDol&9LJ6mW;@EwLw5z!d z8Tx8)dj2#U7OIpM_53n?z`ax&nf}UnOd-ZAuRj-l9q9;&KZA>2<>}y`!5(I+sVONu zpkKNWW&k9}+AvU*%N~mH30lFwl=NPjhwYDaZ&UbwES)D%n;n2^JvB2=g9E0Tt>~oB zaW9ytzTf>}ZTJ{l9sAD%*)sFhm|IZ7pLnXNhO-LFeUl8%i@-qe&?CdP%l?=sCBXCi7I<9rb)i}j)W4bN~`74IR>fF)<1y06^ zln=lOfe9B=G|R9_sKxoA%3TTiK@65Fls?GV99u@WBfUVM{#b?kl?VQ;3QmM-OH6lQ z2Hw?LjExj$xQcJv#0E6W57(4~a_T6k8RkUQhk9(u+uqQ}Ia3U7@8<#&)#=J5eI*Nr ztuRSDjCux>qlX=QOpDG8j`Ol7YlfR;pQ7`|5as8`;DwpM#b#^=2E$1Hc~qgf>z6$j zpBBhdWVqg4-NKer1PlkZ zJoD(4GldmX&GAE(`%t5p^%%n5C9AurzN8*hJd!tfnkJTGL3`<@PQOQ*0IM?17%C@N z+^|;&E%lcqQ)--oEXv@TOf)s3kV=%_Z5eybges#+$Su(NBK#0nb?Xw8>2PE|BlW&DfpTGV-l=iGg`E=*tJE_agk#c@)7g0LYR z>*0y;aLyIGPxYcTW(-PhL96c{PK21l_3?sJ!o)T^XJcpSp<2ch*_V^wL}3Qj^^|w- zCK9nm6@W`)W=<63EccG8?+DV-wWCiQgSZ~)Vs&55pQ*wl|EesojR=^W81ZpI*Ov8* z{B4rbXw(B8HeJ;$MIDEr|18hdi$a3SA1|XP73isY*aGT7o$Okk6ab-`-9jRxxZ{aUE`%c9VPX^%a|r{YSN0vk{jg z=A@O_%CVk8N~u%UNf_uhflIUtnE}$}iQKvFk6Xt+rB`zZmhwwSgu}Crk>(C00`d5; z^e0-elSa0q{B4(*S}2PzM^lUagwiX`a%29X2v;liDrYyJA=S7W1LMU}d~@J1SHgp5)W(!b0W2*waek4f z)hj2l;JzVrCCv^WJXF3`3zU~`%CA@XPt)8JTJfN2EEE&T46T`{re1IVVV>7%7FlD1 zf_AnQs{!=7!;508>M^p(G8le_oI4dhfOr8_C40L zZZOR}^LY3^>E)wTfu12@$&XA~+qCzy_SlAxd3{PM4-WoBvXg*$@%3d?((G?rc~veW z3X;B-JmIR($gKO;b&D+z#%m7t8uEEnudGPn)#-m1r#*jy@7nQ{k(KNO^Bax&P7Bl$xo29NXVI+f zz3-?N;c^nZ1+0H$CjM}*_B>833#Uja%WNk?`l*Fq0;PzE z$PA!}Gr;beW^CCs?XC+db)&@;>AMX3)?)ze6)JAO*-s5ZHnJ2(r(M8Je(fKCJlfE^ zQm>5u%5|bKtc zHwBzrA31_TY)Ce<=X{A4Txkgf9ivJJ=A_)#qOd+886!)-K3_tz5@xvO@O+!2XObpT zWMPk{Pab|hZOYxF)+am?1-Db z3&~*nSN6j=Pa;`?U6$qqY0!5$sGv6()gyE%HHo7xg}*6z7q9q0Brupr&?!DupgJGF7R0a=){!OU6R=^oyRU_n9sgFg=eFBedQa(cp)cH7RNfJ zhx5aFWhBp&zx9@BY9Y^^YUa`_7+x@Bb<}+<;G}fZ7OE#NJFV-nH-S`;Ko#d%n#V82 z%-U*A9!L_$0q;MB5^<3I&UXVhjAQNFRdt+WE66S`aRbO2Ftw(!PC112BQs)> zeHT%yPH^NEX8qL7_Unk~VjRnO73}iZq53`y->pz{)0x2R)CsUVYbWRq{=RN9Jd3-Y zx0c{Y`3Hao^4cUikMaE4?6^RSW)+nezugQgY5BTg^+7wp`jyNn_m^k&lwP`c&LX9* zW5iKoqr^9IUQC8WpLI^$P-l69x(v^S2y|hlx>9QXqa4ubtdPEY6Ov5_#1$*q^Ej_6 z9k+I^;H!2b9%R{C&bG#B?B^fAr1gx?T)Tx5>*#+@@j**Yp^Nbo!uqhVXp_X!;yY)n z1hSGL(jsSRu_^bZLe2cDg4vCq?JDgCCME4?e*iZPvK_i%%wjpm4l{#!=-h+tGUYt4 z7sN*1g&}Abu@2igPN;R*qNf0!lctPWkycZ0%9dk53Yt}y&rKiEd#k<@WlPQRpFU&l z@aH^Syv%zu*^_#ruv-ZXTD3cS>@aA_FmJ5~Z4>SPa788b|F{pkUhwh<@OjQ%*kyqV z{I=Tp_qbT6m<^*x7e~yXX9IhT255!NB&=8=INdo%;t#Oa7Uv&2) z|FM(-4?L+qay8f%RY)IXYCpurjfd>I@?9^2)-V^Baf=kEQ&Ak~x3FRr29NH(WaI%O zxJV^#^0nkz7*A)8&!E0o4 z;6w+qI$8|`Sw(*{MIWr&l*7%2O3@VArm^I>)V*6(C`=EIbTSco|M+oOd>l`{dn!2= z)*q|!m(-M6m&$s=1fhnS%*zjEUWZ}P4J@vX?MtdLNM1y|9Xx-lNLmKH=DosnZ_fc6qM zrw*k9#}c-0y$RCt*TFb4*_|gbtJ7#d7x*g;Jr|aO(E8hY$J>SU@BJx9xV<*=_k@#E zK-+BFO;s*FpHuM1k8~Sq*f`&jWz6k5bs*XOda_FTQ4?W%5$=EV$iwX-{#2H<{Ct6pOg{F9Xayr*tBl;!7^JX1 z!p3A>Ei{8^B#V*Ok!qMjG&#-L2Lw7g?wT(|6gYlG2O9pU&?ZJ*Axf*kGoJYcCEjJ% zb<0L#JXYWfOfYc=Yct8>8rIH{w6LjqY*;NJWx?&Ut?}UFNAgEU-@$f8e$qHZ;^E}Y z^nH&RWw7PM+xw3LmkJYWBCCQzVQA*)`@nr3hc(M8)A4ssx-_OSt9Xi1Y<+04i^{tg zrrF3+`Cdv^*);yS#2`f)QEx47P>od%KHhLlgE*9~lL@cs+bFLZ@fW(TQae;e>qJyL z(C~~oq7E?>FFc@ZLJlhZ?I)nnl-WBV*BJk{p( zJ%=g;eM#-qQ#=FY;;pc`+3DwHe4Q^s$#T2}2)FvhL4MUTIEtqj0OBwGBc#}Z z&ci5_=-60qer>&G@7y2#$f+ac?iWU?p@eBWjSfx7|-3|%}tgyF&+!n&d?YwHM`-TCs?U-+=#i{OA?|e)6$Z*Bl z(%Z2wgdcVS8ltxL{Z-C?#)>AE&G3HNBC*7FE}HP#j#7KzVtvy{(Nh@_Vv3(S*e;6H=S>Vg#Yon7f1G6XXK5%rjbJrs zQt65F5{+_wrYxtz5lVzMQqd-6Q@5+xB(H7xo-=gr0U1}*F+aT0AlzhuH1UA~qznf! ze*`n>HaA?=Erh;bh0~sQ7mO+~GsB7<^7hD@v2N+~oKo%;R!zyhZHulHM8AmDVOSxRrb! zMRs`313VMka>E89Nhqyy*~b7~vjK#IY|~-id)2iLK2iopW;dYk0Zl+Ipu+lmr7nyw zVFP3g?H4-~U6JlF)R*xgS3Q-YcE>)vt zN=acl1~X+@xE2+k9I)068!(eMI!ltBt?zWYkAEG(J|{{Z`8Cq>)pD4?KqkG zsHJ@GeCujN)GX&?xhlX_Y$odXKydir@I;%LG@2yUS=Btc`oFd7YERsCrU5IbQ3Nvo zmom|H+v<^0#8%;{4Cpn#w7EK50b|6?Qf;Q#oc>Ptei2p>2Mf_75rbUBMm!fNVS9(o z7**UB%UAG>{Ti+}AoH#Sp+ojvc6YxtBe{5Pg#7E*P+rO7zmpj}w(NDyVkn&TL=!kU z-|`>Rulz*vz%jpg3t?j7;S$;F$lDu))5i7@^1nNjR_HyvB%Z3w%kl-%Aw?+IRhI() z3{|Ad7EN3T4wzhL5-#U)cp;*bw66@zS5vwyhOMrqGlBtX$jSLf@0}r@UWud%)#Q%ADmq)%T@?>ko&^8hSu{Z^@`Y2N%!sDK3U>dvuTvl_1{mlgO7O z12XJ_Eq~G&-hT6tNHj<-R=T@Mks{5lkuf@I<4rS>YTq*mPi}nWQgA4*=OK$%T^E_p z3Q|(9o(}4tvmHZ*(7FW2>U<7A;KnbF$Vq5Mg$pOO)p{b+&(ri|rR%#G5#mYaks|&p zqJYY$z+>y&`;d6SaIkLa)cDdaF$y;4IMpBjVBFYbqwq6D+r2x)q4SuYg{8gg|1foz zQBi$w-0ks2S_6#ICRL+Eh*_BL$`F?hyS{t z_1yafuUG?X&pG?-_*~!9(}B#uE(%&e9o$&_%E#>1(GW`!_^+Sv$sV)BO}^e|VY1N4 zv)3qF1}=Fp&cPw8CFpR$twY z(%4@zgn?T5B49WEreAVr2OArX8Ckh2q#pODrqz<}`G^h~PVLTfS1zy+IZG2VcPa~U z_yMkRzYKS`emiE{D-HVkojyX^zZipi?-TbLvk)69_a(i#^`YM6MeLtnv)2ex=KjmV^ZCmp%c$T0 zjs+IY_{UP3+QF^uA^xr^e(=d???p+GtynbB(d|CEd#-7 z$I-f7g%3o zYU2(pz@o>#YLj=A@TuL$9~m|j#$h!R$*WeoG9ubFe*_x`9_8YBVmzhJnPfPhl78bt z^tO;wjF5r-WB);5KWLbGd7NM9-SsxKU(=&}RIs_q)^8mI-jTgX;49CkSCIDfcIL)+ zeLuQrwAB@GRCZZ=qr)9PTnvQ|$xUf2C>`u%_!Lb$>Q7p%OG#R_TYn>j7nBD|C!ux_ z1LomO!tccJ*YU1%Wq#o~Yo}o@h?}qDUaQM*QXY`ja`dAuP$u$J^|N!-=;xh0&{ z-K3z5PfW>uRr+%-vQXSoj)L0_7_3B$fm`5>c866qkaa!NPVl}}S{ zJkUBb9d#C6jrTW1d(+x@L7ILCY^^=MLUQ83$_}~nc3~-Pd?JL>28I?2qj#?Xkz94W zS;0DqsqweldP||=$d_KP$I=;bpvSfA^5R+|dFRPDd|&7kqH1gjXI|FRWyz_p^aNp+ z)TD6?;TCxEseJR$RHTmD*$+9p4#s+>$z z5(=`z*?af}7@)f6L_wU~P219U3bfGX;3h>5zjMReCBjY4ySaFu&{bjkNUM=K_TR-y zpp|+Ox>C(hzNu(Lr)!_y83PLUUznL!O*)e!?^986!aA3x&K``P zoL()tBVar{vgb20&jOQ~80_@f^ta?~D6g}IoIst?oG|*iiB_!%34Rev z%}|&vWvu2CnSF7{A09FDbLaTb&>kF2ED`4>?HdP01j$~o>N9jw?}WDro$gR5(3UWb zoq|&6LANb6Snq2)H>^1)br&E?#kq4oZ7o;d*4d9v|L_0#zdw#(Yo??m|9{)Xh=Phf z^wIxqb=Y&jHf0us`jge&fcL7;5C-Hj_!r?Fpn9S^9uJ8M1ga^$jTNV=RAyyi;m7) zR?$aSi)gtf|LIEe@BUjF77Nj9+W=^BP=?n?c)aoOUz~<6h4x@O z-2ke_$K?j|Q1bIhwA^W7WuN@rQS|W%&1tcA#sedId;5p~&RQOrI6WT!DX56WI!{&SSati+(jZkC|i%SltOD zW&l7cy}dGrOzl2#&bDrda%OsrLi3;tz2o~jj2E5wrE)8O`Ft{g>J>ikuko5ItscBt zb-xSgYx+Z0A7tV73_$u$cl{ITzuq`!uhA}Ts;21A9+hhcIX2IVW;FjQ;XwRZg5=uE zoFm4}DnVqz=71>H4(he{rl4&pX5f4-tZA-K4 zm1T%wOk8$>ckRp4|Hp|Ej3TuTPK*nRM!U>3a(Yf{f zEx1^?m>&W0{y7yp`ocG4_X`2+=wz=LMY(b{(^N_>b~odCJM#+O0jR&vS3#k~rmH^$ zwvU*MJvM~%{RY{n>epj=o@83nC{1+h=lprBkB57hi~)BqI1SA2-7kL7>}3a*lKX(Q zYjbH{_C7(LNCFAYo^<@9OVF#S{NO<${G9kHK%u4RmgHU#Xvbp(Zz?gqdvMv(e|(cz z`)G&5hnqWLt}7(y?ZZ_dp>yv)9|NsbX2q~JJo8`x{??V{8-F$XvHR)ZZ0|9S8Lkqp zCHVqh1$H)yncS~`Tck{Kgsnw9GA=USbe%>2Z}K)iW_wQ(!rDjf{&GxS?_c3yn39qO zeT3`KAALn_Om$3ydkg%=PVL)Y+M0W18RSGjw;xlQ<>p+EVkoI_ILI?c4kcCmRd!ps zUiK0CglzT$=~$I%cSAru2mfBTB8Gh<3e&JJ&ii9rfHgPsCDDRy@0Rr;)h5p4kYjgRKh44 zpM=ZUu%X3*BBE2n*M@({7bhPF+r@9lRAbMX*9~IM zs>7clb|kF(P^P-0JqRgN3Ih@Q+`IEfB%E)k%4GN6mkrOGcP$77R(S@NB`G|+6m;`t z#vWMn9S>*trOo?i1Ca2kdw)LDJ^50llWTQJtj2wDgO>kTki#t&vfZQ!oODdP1ADU& za_d8=PlX?x?%=2^#Rd5Jr6Q-3tYrXoh@aQ8urgfh*~`lGTAUDTjK}l zb)kPgQ;`Gw!1z~+=9U*B4;F-<5{3>6RrEHw{6w;0_N3_A23JQ?B52z5?Vj1xk;xrz zGmMEwL%~m8Xc5lk9t5IajkjHnxXKxyVK!tORkTih-;XVVL)nSg&1yY|MK>&^X?+3# z3Yz2ax)Jq&WgGYsl0O8F*CABP+I9Fx7=Wz^K~-SGl<4Jt-_y&us%viy38Yoq-oZl;y&_ zXS)LZ(l7s!*{Wp|^kdM+3Oo!R@)eeSQs}rBo;9_%rfAbbT6ErF1Sg%}*@IZ*6|dG= zW(wHw_gd2!SKaC4iBm0ez;ZTf7QfRDK29n7UR9%n!h9z(iZt3=I!Shb9)LI1ES&dIlIx~p+qKXdNZ-e+vXB0s(20J-HXH~OtKISEmtSNd z^2^)CJ!Cf;E?|7re~W=sUmockJMWG+c;Rg&wm+qv^_BCBEBm&Vh#`90BhO0+zwzuUzhRo4Hh6Yy>HkD_ zkZ!A^(c5^UuyqA`0L{OCJ;sg?ejj>Y%svasW!r+4Q6`zW7xp=QZjvZMcBHpz-9-} zOlkkC$x*l7iwK!wnz>+KwU6iz-^l_$Y(+7t?cQoN0f&4JxuZXTXJ^w@X8vX7D^Yqt-K?ceGW`5>~lc{RU#!%4i%SMHDSZ*4Ny$8x((R*9nawcAptC9x9s1Gb`?Er zKf!VKl0(L&37i#c_HL5uv%pif>`E6e3vHaFr4vxo0ZtCN;8mvyrLgigClxKSbwD)o zSFm1y?U#!gmr+k=t$BV9GTFc7l5ajO;g`+T&o{~cRt{C+0@3!Fy0*QX0emD18EMr( zt-Mw};GWn!T<;}DU?~FJ1tJV2j6XGef5<_PO2ymWdtDA$1UFmYM6CV258e@5?R-5` zE4Vw|N-M18yMg)a6|5E(cYN+Lwg(o2IiT>6CnGMvi|+QA1DQstpBDU-yaCHQk%EX# zLhDZh_Y}nQBIQ808;?mTcs9qN309p7FSEzJM3q4<$VK4qa9(f5`X7OlHxN++C(sX%J`LB;Cn_bD#tI&OGWw&*n!I2+a--TU_p_{VV zAbdJ5-S;K<1J+LW-h*}o-T(}Mi*@BdBM(3`aQ&xCf)%50M8&vdgF1b}C?1FINIcGd z_Gqx@A4}L0Cgxttp0lg;NG3v&N}Vbpz`O2+@C9u(XR-ZVy$ehbta6Z3&v_oraD&6x zfXg3kH=ru*gvH5z4Bxd|4W9K0nl)(ra7U*`F1jg-HU(2*K5kI<`*gp~WSoCSB4suFgA zU`LZs9uM-TAAvWr2+PjiZiqy2&OfrZ$YzzFaqdyc`vs0LkB%2@Y@Jijc1WmOsOM=eXdrQ<&BPq)p+ zKD3<04E}MXmVX95io8%N)y?%|_Gi57&8eC>D)xrYy3sOZ-BtwKCrf zc|}`YXfGsMa&qBeM%b$GIGBlez|y7Bl#e9|`lr>L6MW{#vk5whlbzGd2a&Ggw~? z1!jM@_$*eBQ3T16^&>o9+*cSv489jZy7bYV0ozzUs?koj+}AJMI8(YUW9n^i@PQ{_ z$Y-!YL~z1>oExZ2S-3ImYPD*f`K7X^N+#FGTjf(~7^6A`)pka^4swG^SV+ti3M`NQ z#T@-&W~ttL?Q0{ zSqbi%c7-O$StGIktS|nkvi@*{6HZO);N$gNLOCL%-AG_gwx^|`D-fcOAEsp@J0;oO*c3p9 zQx+@`tkV0;H|k4(a9=jlDH!F@0vv#HNGc6pj8``H*T<1~0~?ohu;O+hZ!YygX%9~J zOp?!AWu6!q(- zescKJH@GnMF$A!f>~J8UWke}IF(^JWY-*nNto`^cHrc)F`sybJvo;8WL_+mP!JtoI zpHB{3cyx`;B7ya1o0I4n8P-sFJp3Le+ULC>Bg)VjI-8f{?-NZhIdM9_~E2`^USn=C{c zt$=bN)by~%0g7hY+vF059w^T895x(!Ls#iBAkAt&!dpO!qDO?8pK4li*#<3B)=^7@ zYlyt*I`Zn!7W9nyQXFtwvXdey98AAkT0FlP?!HftC49xFu}A+V@hShYM(3gL{Es+3 zJ81&Qwm8F#tMOWKC`pa`;!vDNHOK8M83t=i0{4&jq{KRuOTxF3Bt(kL&WGZNV=RCDo9d1c< zLKD?XTk565Wa-asew2uC>n`QF7gDgk_B*G0Le$|(1uo$gR;HjDzJ1CRk#>Rk_bp%_12lrHWDRE&JKl) znlakMunvTf-|90I%Ga|6+5DH1@0<9cV)SJ_(e$`Y#3#gNQMwq4X11uQ^I8ITvzFG% zfqR}X3daCV)R4==kIGTwi}K+1$RE z3X#wnp|hP;#Fi5Oyw&h{Eabh=idgXRVmq^f7}1v~5P`J-$BtIkI$j%#&BD;3Y;^eW zm)9KdrPX=a&6G!F5=M%?q}q7pHv;@5?_?~=q;Gvzl}`JYnJlj01+X`_m7y{s{)izn zm&wchb3V`26lPxoCI{(pbCjYvp6@FG1yv>aL-`=|<;7$N7GJgm9P_twtm6oe!3Qsf z;k1f)#G~51u;I>gFQ~{>`kL0AC-_e;oe3N_bJTlt_D6m&@Iv9~n-8L@>*gyUtj%{K zX{1RZ(8n#@UGs5s-4ktegjyiE$_5p5i&kk9ODNvxnUpWz!eJ#DTn=R)XdH@^l-))3 zZ_P2vW~Z&5OiFF>sAWp$clxaNF>bbj%eb+Od zs1Gc<5qzQAI@d?mgaUoa0;a`hP`Kw5$^z&Bg<~b+bYFr@L*Sfi%5qhJ z1-!W2x4A)KuSuXIcGU;65z)b`Su_1iZcNa3<&sDZh%`Fg=7{4QYzAfDSQl?LmN)3g zHN`G3bX*|?dhni56v>5&3>lNq2}C|h*pkNW)j4LkRTW2e1luPRii#T{|C2UyhqVY3 z&D+VK9y1g@<5rga9$?17g)BS%hG9|Ynwh`acfDen7q``6f~Rpx#p&w98n3q?Go@BjGi^fAZfNnGp&m>?6i@3Kh2V z{)DIF{aP{Khl%N*ttbRcZDftw-3+Wir$v>l2 zDdUptN?ea=9|;Pm>uxR9)%wq!=^QMcGhA}4JDT>jTVsJnV7Rp-1L}T9cQA?{{_6%QEUT160nZnn2Y1l6@Vk>EL>amU07L<;Ut3->3{^jJzr!Fq*z z$tuXxY^<@NLkA5DBjO6{{#PPn1DTZ~UkXkJ@)>@;{S%3!emtK<1U5z2$J>99Zq=F9 zjLN@4Qf0DwpGXnB`=WKYbJckp#O>XxJv$-l@MLbOhTUhLAJPhw$#mIGZx^*CTQuSx z3Nx7e$qYQCJ?=q6Du9@N7wnYM#ssUMWjc zzK%%={YuuiZTgp1eW4u(neZ$Q&4u!Cg*7=soM#b5F(AUh{%7*agLS4tJm~^};%t#P zXKqXV#8&YLJ^0z#*3#M%Nnhd+j;}s+ld1M%V2;9bLm~DNb_(`3U17HX&DV8P%EbuM zP{ZOmpjH^ypMzBzxk|?ZB80*B9w3;_&$nkEPa+wf*^ENy@v-#pt$#_6326SBPkixG z46u{W2U)!qpv0)2MEBBe`}#uQ#+lg@NbfgZ`q4=TKJdzx9L99j&s&yMOYGDceOG+S zj55ltgek2PvWjG8#BH07a4mP{lsNyh`3W6oaCX~k#y#*25lkO$@ct29T+ORak-%D< z)T)Hf(f|R*4r&w?==w3&yQeR?#wapUo;%dozuY2k94TpUPZk!Qw>1x~7qi`8 z3+*4dyX{8|PaBR>V0dzjemA|C5<#PD<#oH{gfx4zQ+VFh#76uB%#;&+Xp!? zFtb=PU}oPr4;=@KDsmw*pi&cC`sa&d{kFqryZM$JYS<81#Yx7P!Kq`Rm%u&N)B4t2 zixB-6X9RI0xjwITG&<|;6`P1{BFQBDMO(J+H#xxca2|8^>utS9&4qTUSf?e=yK*MN zp_q~I>YU0^;u%XxbMB}x98UdU8nk+T1rM3?tdMLB^v$<==-eP;TPxeWzr_fjPZK7ZH>RMZ8|*$F70Zhjwuj=srBl$3h)RKCTT zJ|3Fp#roK}ARj*B>{*nk0cPnNNe-~RtXP{&cCbu{F+@LCZ900{kXyVk2F& z@UKtGdIBUwj5P#|yzyrqNmytUG}v&Xi4OCwC8Fvb!Xz}LErH8|i}ms94&Hy3Sy0iI zpCQ8SL39#dYJj0(qv$FOPt;_}>SmMtMzU$?<4pQrr7mE(v_ZS%>oc~mc&b7&yV={G zC;GQc^zED6%6A-w!bU+bNvoW;W)6#M0cp3oh@NB z(c}IJlHG{i<#%qI>#DLIupo~%ofqjBg?24ekxJL(at~@izJ{+t*nrsB|A6`wn=eHl zQX4%b1)fUEMX1)!vh@>~BlTtst18V#!*&?h|0`%Kp7pmMd85wP!0truNNAw)ga#R- zS}?Ej^?6w8?ypqtpE3EYk5w^ID(^!ZrwNGHc%>o~o81kPcTExsey&;jdPfiB2#(h8 ztC^rotA&-rDcaq)O1Sf4ldCd3G({aj*DO#k4VUgdFPvSy_m)qreM)H2j(i@}AZ_TM zV0Hy}J+5V*tSUH>Qu2w~z@@x^!}LT;Iwc)0!yNN#Re!4?X{HA6?^k5ho~r))r>(bbiJRVSkFs?*zPfqn<-X&&s5mD_t`|k$u8b0h)kTG{=h;X~eDB?p zBis)b*k2OV`0bc3YCAJwrGVs;?a{!(0I2+p9yZy!vLvUzCQ*{7tvvjd|C-*Tj^fxtr!NRw%ZA(ze#j?&*mYJpB$4^9W zB)94ky+U4O(>Ldm+os6VmAD93DS5ZTZWox2OvsS^BWEKGSsRio8K zleD6H&4$AlmQiYxrhT3-84_WPtBSV z__LpBIWq}K4#4vZkBhgY|0eHsCHQ%#43l|AkgK8}7hocv^sza`%4Abg^=7MN7b)Uc z`mkW{4c%l(>ncoKRwa}J6U!5ZK0+oDg;!t`De@ALH-UKpstR_&|)GbMFpv+0)TYAKYS zFzr)mNe%ES1`jj#C>lo!3Yg;*-#M8iy%Z~%pdTi)DLC!@haz27pd=QF0Nw#Wu)tQ?#qA<6|CJeJMwlAkA7=0 z++_cm6ZwP_sG<=QE}Edm9+c4Ey-KM^K_n!ZmWLc(Sj-ohbyz1s%;NkgU<#ZNjYbSH zvh^wwdB1|2AL;2?;Z+GJ99r_R)fjtcUMZHo0n2)>><~AO{>RR<_RbTWt-$UrOZ?9P zb(gTXKn|7;1Nf(^=fS$&CGues%UsgQ{dHgMbAHS${%Lc>i98ria3mTf4p+1+ zsjcmY!dU$(+hzz#Z{8V5ivSgKHc|NlZpTsXh}1RGVNo+wus@CLXM1F8BZu&|MX2^{ zp|?|_m5I~?q%!YVmV*U{4)o($H`u7 z5Xsx&k0PFv2fmX3R(_h@rZAgS7*aa$U-KGnS+aTe{izc9@E;nR`EPMh1)6oug&2z? zqUH^{a^Q5eTFqj^3@L_5d>T2^kk+P%&!!MlGoZQ0d_GU;+2{koPVe9A=kxsgw41ep2lkB z`GU4DrPN~v*2Q;BsCPqVpzhRX_%O;4Mdj=S{kHUpmjrP+zUdzmgjwh3X6|$)xG-7Q z1&Yy<9W{5f^j2@3DcSYR1&p2$HTQ+9>&2G!&n!i`XexDn!K$>g4Iejl#s+OEc;vaZJwG|68Acuzn!lzXgt!pjB%lo zU-ox3DV!D`EbeK1aCEP^pjPU(JtD1AX||C+Ip$wp-8oPeF`z;Ay1nuVXMbz4_^b|# z14V5L@q)TBRBs$1r{YooMjB;UabdFCfQv`3Gq5%N2u%e{CEOiw6Lho8yi*kYEF5IO z^bN35FvjEm+!asUMJDi2FM!)E2+6JC!3YbNOgK?ar|e86=_rcG;zj}US_hz=4Mo<5QDwQ)afEiSL$rXli) zBr03NkRn&-+X?3Ks5tGwZ$v_iWiHgVW~O=`vcA*dlN5C6Vul|*l;c^31xpF>9Vy-C zoYO$2(wWtB4}S~6rB^UF?EAgwPO!Z}wi&I-jEb!F@3fX4O*>N=i^`Db%6r9Va6RJn z-P8y}QM%tyiXpe91KtL_xDlRUdnXBgqwVLti+2_Yg`>(4eq;@5O!Q*~h|Qj>MMzS= z;7#5cTc;gtKwReec=X3%8o9Sw!mFo*4ulm1c$30;W|f!{|M^**8j4{GCa#0Bc~&H- za;9qfKg{>J2p`}B_=5i&DFW$8iqah6v_*ayp`&3=DQ zE9C2(&$q#>WJK?$-?%oP*XCGFr9%UOtHEnHW+9bKCY&PEB(OvOi*x4h@gnOfZ>0+x z7Z|(+zy5Q>q;mPK&?SzXq?J8$tqQPfaS1M$n z-(JTqgphp1Os%9p#_X=Q$mj_enOgF-w7;3Mo*qOjA}2n(oF?ZQ&9KyzD<-%hDM8|Z zJPO76=OhZW`zj-%19Ys9*mm$2|Rk{eu>8b5n5y;Rl{f1`^1yU!4i zeD{h`TCbFrLKIX#PcA-xCs!vQ11%Y@aozBEYZcBQ)6zi9j2Or?B|6*1eIT1Gh5Ju_ z(Bu2dht>=)O+5Vz)1T)6Q-dKhk-cC>D4Yc^-P zm3RCO{>EapWA^2^-PW}LC(*HUb+^1IJMDVwZa2Zf3>0d z8~3V-Ol1IsRtHyB(6F0vrJ1~A{l=&hsx-EbBVW&rd-yL~PLFS9cCS37Zn=MwtR(Gs zZbxV+`))RreXn*oRZlm&m1@|twR?eylG8~EggtL+ZFvx4K220?{a`$#H@(kNyD66x?>kamGv{$xg5;>iYlQ&tslDwNl9sg)r`c)d&S zINY(I--Ds|;g1x`gW~Hm8b315$?D{&^QCDQ%)`voc5-%uMrrNL5s$+9Tf2VlT#GtYFuXCd3kG@mPFsA=^fwz#4X38Ic z&0obC^~bYWE%fen-?B52kNQ%lTF*yRWCl zFq(3PyP&&5RF(I|uC-z5H4hrcEu2b6R&lqeSJ~#qI&m3|%O_5a@a; z<7Bsde|N5(uX^ZyN=5IzdUc%9(drNAt0sToBYs>>X0-p}H;(wOHY2t0q37KGEMUlC zCrtioB2nwN(au4uB+7Lv4V! zm45HFOeiuFb$)XPFu$6m?>-WnXY>KTN{wm(rCO@#frq!MBVqKAc3a8|bv6p9f#y2B zJAC>-w0he~e$zRZ7sz-h{_#Y7%7<;&+QB7|&AX+}tGc#4f7_reE6=LyGL~os@l_+; z*~;e;ou$p|wV8urdJ#9Rhk&m8&4$CnX!%wPkAkj?GkP#%?XTdj0t~ZjUo?2kEr#0+^suW`kWzmefT5(X*#qzx9nWQ}PV(<>)a3fy zibAXIx)^V|xDLNg;qkLBKjbmmyI#+S`LAz4^GRquXvFM66h_V~UoVQd`mX8`NjGY@DG2_|g_Da>3PS@is((`g*^FKj^)@ zm*whpf>HBiTB&pJrYgd`Qa8(%zud0Oj779OJ3k_^=}bCw zPG~1>sIE&W?!VA_p*#s41A^@KG6hy}B3k#_D}~?Fi1#&|*9GNuEoMyx0N08B^^V~r z&}Xz|HZbbAePzY=e6JysWsmLq2SOY?z@mJgP_{R?)RNCwFhgEHS|-r&R~e?W^Q}Ea zU%E|y8wf8v-I~Xu_tIKNV3)2Lt9vi2?SUlnVn@m?pI&cnIqz9rT*Xvd^d z!}QN1{iqDN+qpiu#-p+qF!At}*1W4skMolD%o=! zo`m_#QP0w?Z!b&~sng{MB`$d;w8CI7oE~y)wLDDS5`+G>6&G2%x;@L`PHKLVX2oQP zCUUD4blN=QsvEnx_0NK{Ueb7X0d4GBAjUQqzC5`ZtsEo2l!KlfckG583f<5CXmGul4Ih1~gHs!$BFeWBqY`Ixn4jPqu z;J&{zf2tyxO8{I@lR&od6_@&tC4#-9s0W@&we(hBKbdFD%BzOqnRB z+?e5OEcAPJOyI)Sc1Gx)O>uNT>&#ugw8gAUngt>4^oHPWq>d79+t>MGSXY>%OEDm7S!a4?(pw^cqBgBw5AyAdN7S)!K zdBB7wPHO-vyS(a4kBooQr$T(0Rm>^b&?*pn4~E{(z8Q&E&{W?Z58s*TceGz(g7~A$ zLIYcrY4~W{>^C!?>)$KE3+pIii;M18dC?rH*v~$?2O%*ug=LOj3B9hT-(Io(#`FA% zr0UKFRU-S~PG-^d)j1$fQEb^^zKt>a4`JWsfgDaXFyX8^y#9`>FKsmwQSD+#VIFTW z{)-0w%`w04mz!PC#yGYzNYum!xE?VBdJvBQiMApSn4<)L`iDV%F)mZ=b@Jc3hZ)E3 zA0L6-?i}P`mL>{D{_vG*22rhLUne;geMD+7f^QFk4g#)iIiWSF_{sJ&qy+o7X@l_+JA@WwGrP5=Kd<*A25zW$)VoLn1$Lf@P}+p;|a{5sA8+NEEXRVwI|A z7y|{68;>)CiW2r`eNNA4JC0~?>j@#kd7T~0CFHXJr!_2-=;hMd0ODail@@a_4%lvo2GCO|7o1N-X9p};m0r>B+=MY4QSatV8xxvpfd%9m{7Iz z;RGwO{fZHS#PzU$|;O%M%JLaBU~69V+oStL??gHab53 zx16PDYquWnN5s2jVoAc{G=bmd%F_@lljMhSh0~D|4~N64&>%}7N&;!(3ZRj=i}NN1*5hQ zv%FY_jS5yEJ|a{9_KWiya{@?}PnQ1KV@c{}T%pSJn#d^Y6hHoEx5bsOyfD$0QB`K- z6VigA03Z&Yim`cS(FbG%B)Tc^ZwxFJ^;XZYd=s7^y)Ern(C60vV;Zz_9<=gHHRvDT ztiNl`n^2WTwe&M0hTEcp2Fg8eY0wMSg%>Nyd&%?D2#?o{6TA#F90zO`6VaEk^dEw) zed?lbd1)QB`in5+Q=FC)I5xF|h=%0@>VL~}vUm?8ioWp1$t-PYiEHw8cg*$jHD2u> z-0i=oBI+Z^Z?BtDZr0M|Egp;*auMhTM<{yLP~cfbyQO9Nd0SY%_A>)YF!FKmI9DIa zDLm=1;yX`RU)t1BP8#?0jx)c_==41dbhhQm-w-LwyxPD2p9}9G+`Ltok$B5gMQZe; zo4QoaxE2Y^gU8$l)#_tYPNcdKp>p!kw(|wilbM%d3l~M{!0uHK=D#qj_SQ zIFg428B?8RzW6TdAvjr21KQC6r8~+7-^GcS>8ChOi&v0Cb5dGQX2tZz#eZ~>a$LE~ z`X{{Z3irzC*PAs5c9dKwL!b~#|~p;!-2c>+~A zQbz7(&duR1^gn&AewqHOFv?wj?HNBTZG$`hQtU{f&%Fp_1uLb?Cq8|~%W8B8n)fKC zU2^c(MlL4yL|GiETwFicVr6YiRx%5>4P8Y0! zgM>d%*4R%+I?1j^MBpL-qaa)#ZB9nqt>{~|Pp8k^q zx&-n?^2AhX7W$3vlZ-X7$J;~b%w^ql5p)9dj(Jwe)nWQG;>jL#@l@Vl=D>cj7%dEV zYc*urYvG#Yv^$GjHR-^5r_)LM=BAs4A(x5Fsnn+idJ-`d+S#?5< zrcWFB%hu11<;^2Q$-QJE1NudiaI_wRlUnk-1Jrob*ewZGYYanO6g|llfcXPrWixKuskdvyS@Nu8b`zFP6UF90Z z>*ah&g$%*;1o8eSx`8B-Pm_ho-RoL~0T3ifVJdI{F-s-moVZyV!682t1&gfq`#wRu zQ;gDC6|I=F%264n8n-;m{)lIm%8RX(-)6LOnyZ-c9NWV&t@8oU3eZ z9{Rx_hH)oNjs9i@bB?E1;q}xFhK_%70vRdJJdzr9V|8MUe*`+_IZ#5~^+F+P4zu5a zyt}@OGI2>L+2cT-`7;xP^HmO~CpPh`G^mPNuo09HIsGGvH91j%2BFod zU0Pxf?vzP3T2|DmgfmZ3{RB1*XO2?E3KIJh`u%M}^Ch2%CV?`~vw&K~{;xzf)t`RK zv}m2hm@<=p^G+Kr@8_TlgNq8(mqUA#d|ulXhZXNhm}@XvQ`xqMd_#ePH2Ls02=3X6 z;f=6FC&?sBmkE11I-!>V36xXbe=zJJW(nnO}CeK6fTeMe4M3G5k<Jkja2z77F-O)Fl1wLju+{y`xg4)prj;(IplR*)BF#7HkmQISDPEc2> z!i9~)s)7(?L$4@L^Cc}KGWx4TI>ELS#fL3U-$d_iS5XPUe4E${Tow=xBQ}qgWvW!y zLNa{=aFlE5BJf4}?_Za`=O;<5yJO)H-IV#K(Y%H>kBwg9FoG-7AvkjuMLoI5u1Z{XKNYvuT=spG zgvGuz(g?_xkmR8EzzvpjQhp^4nrq4Xvi*fL^n#@o`qNEF+LO&W%EeSsjs4&op5;et zdK8OcPt2{0JG@zb?KxFMX~|_&jv?>ih@r0SBtOBeg3&(rzOQ?zhZyOST$)RWf8o|R ztJIHXvzUxQ=`xeEbd&uaiRPV5`0NwZL1W~hgB=ZuI!@n7Pa)q48VZ|)F@8U=mZv3v z=vzch4RL#e9%&@TQ@S4V{KBSVAcQUBl;yFT&Sd_<>i|0qif;;9+ay+I^6G7Na^#t2 zI)HmDKdtqF{Hdg_pCb2oQkmy4ePghbg4!UmnR3+398eS4Xj7CtW|Q+Cla zW46ghKMW1;jfA;sJ zuV1?~mO<}i7x@?tWc!r@r_7*Z?}drLZ+~DRiY>n9IqnFD5XyW+u1(R(^MK=UO=Q+2 zLtJWlWly{PvlMmlqjPM~B7<`kyvyk2eGA6Wy(2H0nPul0!8==xtWNCK_$z6TTXhL+ z^iI02DOQ0wkEKnLPw<70>x6=l6*!14!rmMXoH4$|V(qM`+fSY`63Ml_=-6(xvAOow zb7F7vrjGQbUZ>9-_@3@RYI~8=x1@X@WzZHtu5BKT#>W*r(^vb8+$Q}FUVH>Rfb#B! zzsc~uWS%R0awe1bs2hN>GZ37a2>3xXWlTUPA1RAFr*+dbLC=}L6Kjls3@1_IFY}6v zT{?*dZmajxdXpaIhp2E~#*7i(jX9jC-^bD<7K6VKKMYXwtl_Q`4if<6Kegn{@|Mak zb-7Y8eS@mzvTgK*jGRPJYBhrwJR53%j%d^numFx)0cf6hJvSE|gm)T#EBMxwZ$R8p zux}|m(4quVCB8t2`*E%%{u2M&8|D&mO=-(2f3!~@X_dZu?3YpB{SGaYl=Qzz3Xu70 z(SDQi>rLr4;!xnOpg93{YpJbwGLDW~6bwKH##K$68y;p&TJE7V`*bh$Z+n~ssVcU| ziLJMz7od<4lyj(<={pr52bwtHm`sP@n?p&P7@_z8RUnWhN;Vn;KgSYe)yg~_g;CeW1_%YZJw6v!N%;36eh06Fetw0lWvZoF34sOP496esk~#Zus2 z6Mo@)$@GP&(0QW>ukL!vWZqKXu=8b}0-7|pYcd(Z-O#@+bmYGjk|j|;*sC~rql zY>i<*zI7jONynEBM>Jm}M5yEZECn5uBm;zR1(Q!}gxzGHGo`vsl3SRU@o?x&H-w@5 z$T61JL-2XD(6ZlB2hNS|B-6*4Sf?hEoK?NQp_!C{r8Rvc`R58()dBw_v}Q%9FkyFMWbkXYpGq$vPEjbMBZ zC{1(pd<(Uet=`YK9sVl-%2Zvt%gU355I)c-Si>mFlHBoXr-fJo4o@s(cIbhF!24a( zXE`nAQzBV4_2*zg1_LK@QK1+rqOhiZJHab`qi zCesxIBn*3ebam?wI&v*obw0pxj6-T|cZ}?_=WXL)%a)?1>;8K zJF6+pNiky%j`n#E-Z3FrdU zniK%0b55ZQ2n7s>2^It8*K7d!^4mO6lUXg{ilA+q=sUcoN7@~<5*fO8_y{(EJ7 z@0%ud^=;@|M9yLV`juUG8ILNn@-b8SNdYLtae;1vFe{V}ECtCx&c#|-7_%xhyXT-wazon`-mzE6QJ1R^0 zV4^mFIu$ZY6tP;fFp)GC zgxb}qWg3Ai-NUt07gFx%picECxu#F_fVJUV7*D4raTi7e9mWiTRdmhepT@pY!VIq9 z*<5v{2@UA^np7ES+LZ)J8q=KTyQw@7+(Pv%LYm9w&~_IN1thR|Ryb=X#J;{;*XHCp z&?v)YL>;3y`|n+K7u;A)fT9Iqs#gh=#55s2GhC^d$^)u}Dsb>(Ar6$<07=uky_<}~ zkY9vH+(Dm=enn`5-{B8So~d)y2T~}4j&{`V#83|%C8GoncylwPj>Y`@hfV!WtS=Id zv(6GejNb%+*!YlE?ME*nL$#{b!$GZbz^cA9xm4jTD$R2hMyvpmdwv8v+MZ~zcLE`e zkIh?P*>QkJT-<5+{yT=ScK5q5p*!XFRLbm!DZaI8&Z3!jWaSVfPt_y>!&s`FUk-nS zc_71Z;{Lwd#E)|Hq()+z+qH=;@*ikifGp~s96Q_XK>WB(wtnwzDkM3GU)xuB6{!A( z?6fRPR1amFFAP#n_#1Wc{eS}L;rku+X;A5WV9~0y*>`-*LWL;~)&Q~i7}N!%UmU^e zMFF4cnBoO~orPU@X#}zU)Hs%22nbLnR;3Gm|1+v)o)E2Y;+I(}E>19mn3T@7?#0)u z16@U(T%ln2rI8;apO-+K)Ii0zR@7xdeq~cB03rm#N#H{u(dDW<)k`hKOf}fi_;7S%AL?o>O@(zCv-MUYtlAT-ny9S;N$36G=uTP_L z5##GW1n|L2AxI9^hfuQ58tpmUHR5g#_7?P33D#*5@Zkh6MR4Biz483 zX{RY16~oL&f1sK1QUonB+u5S~-G`#BCSwWV2l^qQ0d-Z(?4-HZ4W9>oUft``$?heb zNVW;vb_Mfm^ZX8`18X1DAnQS!z!8w7&WyGT78=%`Qr-8VdxvAc#Mi=WE&MN=w|#|Z zBE7#;9~<3!=ksf8#+V6fyTz)D567O*vADXqqejmxa0TiN&U*1qXd3Xz7=dP5TqY}V z_T_&;6@RwHkwe}R02LExS{6?Xf(3#s@L!JKzoHb}8_*Fa)p=vfm5u%fD)RfJ;2Pr( z?xt=?dWC|)Rnp1FCek5(Ww1;sGAoV5ss`8D0E?~nlq{=*%^g> zMTh_Jx-PuPxJ%fxpHOc!JJCS;THg8Jl4E>MBzp_RGKl%uBG8V;9Tr}9{M4}yTiEJMSJm;9Mm(M?H%ju;l-pu>2w?J$lAKc2 zWcW1)@OW|G)EIfKoy3DV2Oh)CD3J#=8c7M0-t|d-*y3it;KvV(W*={WqMy5o_N?Wr zQU4(TNgwJY(3&A21gBfsgz0hzsZIu7GgIs`s5$uEOW{~nYNPZhI^BHhPF=h*XPKu` zBaaA&D@J6&Fa*=)CHgI`p_tk_lHtP#^!ByzBx8yOQrYf6Sxm${HM+xX0I2Bp#9ug? z92~bw2K3O88|v02>J*pq4VjTVEAC*CdKP4O%N1Q@W%5nR8R;Q89ZTnqXS{BXRMhvT zF`XcBg*16-hdZxTkB-q8FkIl*K8Xg<4`D|_ld^NYAm2NdSV_i@ zVlJa+m|DZn5RuIU>ZICuRc-cUL|X%|PEI-->%kHz_j1=Fq&L6CBFFJRiu>mQxKc#6 z%*g=)Gt%Cy%~0fvM!B?^Lzj+v5jpW&{mO}uy5CX<%z zJnwZ<`60VRo(7iDFe|hLZRyA)-w;7b>!r2&s7x|R@wrDo8LLJtPU}D>U#rWX4f4+! z_HYVKXZ3w`&SZ&ySk3^)@{mDX&nQBu*ZNG}MUH*j{!54zz#MJriRSpbx1x?V|B`@W zJo-DP!(G@=zSqJ#F2)N&5?+b{-1GiRUWaQ60rnKe(0Be*@xKYYlE34V8`Z5$Et$35 zL85DsRis{dg}C>UX1Z|>s=Qn_T<+ATwITjS>`SzYlk?>A?a#3if#fK1hsw$&yBEH; zlIo^C<~!uAuPvZ|)nR5S=K6s(&wNrDoOD@yEUt$}$>G_M)Gk*_(P8_PV~MIBhoumJ zq|EI@xpD~Ei|3*CP%pASJ5YvHRi4x=3hqs%?&zL#KYP*$MM#ff48jQq^A0QBrcv_i znkZb4g^7@AEO^j!wT8EsvTJ(^H)3X~?4|0xXr-x^7pO0fD2qd~6PyXTM^~r)_lO)T zncm85OfBpM`-&Ck*}J@l6Y9aLrRV;z4kFz)!L((s5YQfWB}w^`9$XnRKh@-3REvs* zO4c&)Ff6;UXnHuY*j$Qd^uLxQ!DP`B|MHEiR&4qf?C7wB9ZaJ5_GYLFBjeC(b{AoV5{mK0V=3 z)1ctv!c(gwidyE8DiNSlt&if)!2gal6rjk=c=<0j#axJD{7W zM)uXVY7Bhvv3MsN#RrtkXtWFO)2JEDaiDQEnvl#IuIy;y6$Tsll_zla9~}9VgcAGU ztQMD;#{DH|o!d{K#bWsPMNq}^o`|UZX8l~uN6t@~SC5HX>|_iU*FQh;_*cmZGKQ`< zp>mk^ayVuF($%Ku2GOcAu>MR}>;88zVG87b^`N~$^WJhMA1GZsqu}6d+KI1o&0jcq ziJ|vdeg!c}XXICo_o3P@l^k3AEm}rVLE8kn*~9E(gQ%Auc&$pYh4PgjKzEG3#{3%V ziElc&P^{8xqW3^z9oKb1>hs91M{9dYCYdcW9_PdS>Ud`1KAACi#i~*CyFDL%A=l3- z^t|8FcLq&yZTIL;`p)uvh&=cTci?*Hzt;t%fasFW_PK% zfrT#=8B`~!O-@#?Ob=CaRokUQpdFwYYd zI6W6HShDzprZ&WoYSiXNy!+5tRC6Dlf2mf)PODyqE#70U=FRWIoZR{{|W1=57Y$-O4+PsE}D)rS4*xn0W=+22*RhYY`uaz1-*EZ!-ZDNCNDL(Ij| zg58jnkpf4nI0Vte5uI;u;NdGCZXBN40dX0I=3?jZ|9sv36&NZmJ+wm2YmJwer=p z+SssZ=1yraacuTZDW}a~3~Uf#1QgVKmYSX04Bg_ZmTVd8OdfYF6O)#_k;Bm=pCq ztqkz1GBAIQFUu*}Sm?vvwflnC@xj+B0?7p5tPE4vklVd=e5cO_GaeI|V}tP>E)~W( zB%y^jCI~OD+xCF=8r|6M@!||*h6dpk={Z?{QtQ&fQwzi;mZBOof{E5}eF>xDkxX5~ zu=h;;Y*U$F{?!)g953Y?ulh+tk{fm%-2^ z#cVG;iVAp=lApYyn#Ep1hnrRkN=HCJsWselKU?t$WGyo?OI!d;2x7s>rqR#gAq=v? zM7|t%QfVl-!G$h8oglwrsvD-=_x8+dhPG-+9%ff?#P4^t6Mfxp5~j9(1XgIh9PD>L z)0hY7B_U&xkrnU?+m2!jnCad}`y&h@2k;bf$}j$znpt~A+%+d-jB@*0hn{m^pD`$5 zjh|U8H6VeSlxJ)yUmk2PKGSv6w1N}d&{^uZ$nEsS1dzrgdYhXRFFzSld#HZzQ)X%H zFUnKTCQ% zr`Yx(1p;O9z|YNIl(g57rWsOhtP@sHH%vxFN9>Ga z{I6>w|*Z?WfZ$pNSThZ@xVRc{Lk=>NPH;r?IM zq5uDU0RP9o8D0M}=Bgh@sw)89*_(>&ht>Tm7&Wr;`Nrs_!HD;NV^Q~kPDy~WHMa+* zJDS^N^Zz$zaORRS2g@JN!H{)xX#mi04%R^2?+`ffKd)?4er*uo%|9dAVfU+$}4be|?) z9w+6ZeJ>-8^nC%HlF)F?T8R6Xn{p$ACt#pAM~*MIj@mGGrI!E8oxdpYJ(>Xob)vU}6(WD)Vk*`j_xP$|8EU1jN&n}or+A5Q7==;)Y0vT5LbRO& zIP}@{ro8LEyc{_7%SmxGNp}B~TCzvQ-#&Vg8t*~D^_u%8|A!_&o51Is|4@VAn4TZH zUVLsW{ckPl^EU%-{&s9Ubj-;B(po=Hy}vMdBcN&&nS0H)mGkm1hn_?~5DX0QdG|0h zR20CSRzKhQzcl^#o;!W#$SlX|c6#^~DpKzaATv8#0o6-OD~y=t4WKgr!(QbkhUJ@#I5ETY<0ysZ55y887)ho6FvPP>2-oc_A~ar4a6y=+P$H?(r@3xuQ+F!XbAO>xrWog@HmR_ zfmogG4nUb@s{pjR`xO)b3|3j_5vFP?hgrsX%Ky-R$*h1mhtYo5ah`=KD>^X;AT$T& zU}C&ESOe?(M%g zn&?eNW#@=Fdk07cO&Zn$yTiOxx(6QA40(a=X_`$*`Wj}P)vRAOF;q;)0+MR}Jj{qb z%`kqhawJ>BmZm{bOnm$v0J;8+BIQ<_twA7ykzr6h; z>#lu2EAu6SboDnf{_7%@nN4u$LI}vA8AT{73Ivz?BzJ3um`mZzNw*y=R34@pb3~g!d zY14uy5tKV9t0wO83)3s4d29lvy7M~i3pSG(W8V))M@_(uAU$4 zCJ)C3NgGw1nA4||2!5D~@pBaw0&6dyJO_nL`rX!*w{1L~Y2>haw2UgHy;Huo2CQQa ziOsVofw=uZo=&E~AFbMJ+_vas9S~R_4x<7(JtHCFj&)}keAH}!(UNrS8JNwFa?y?` zWJZ(zBO2M$D&$<(gLNCXjZIlQ%_6xgABxZ`zdQg_JJICyiHsk7ff%}548S2DO40oP z(ED?O_hG<6EnqtMY5{(5R6FL&4oxwZ1ixU`>;K4)Nh9pu_}^&z^GrR!Si^vWXnB|O zU(B=<2$_TX$C=Q(YfX6_O$`M#FL%_ zQv1AjYrupS(J9^N0Ggm5W>$^}@*89BwOpGDygnmnS=&O;idm#8??h@&%EWRc z0TQ%-|93q@Z`h_j%ZJ0aG?T4m=u1dT%yFwN4T=v@fL?!1qiqAQryGrgO+UbK`?>+@ z-4`?Dd6-2Xc6OR{0~+fKMvS6&j%)y*;3y*#|T)E_87+5Wz zf9|!TJow+|H8xiWq~TSe+ekyjVTp2LJc)F``3afiDj>wkz_c7t8hraS^;Su+iX&0r z3zuwHqPb?@q9%1c_RAl(m)iR$0Dx#6^=n!sEiVdZFY}x6M9b$e4)G*z&VtGsoe>h? zy!COA-l6vShAJL|JMxW7a*e|^Lt^IxIo>i0<3v5@MMvjIJuf4_$WCSfcw{(CIOK5h z{i8fz>3>>k4^Qwk-_%^9A(^J`Z~~!{D@pglzu{h|&TBsUpRA!IGE4otb4J+1L?*yQ zl~zP$2!yx5p@~2*34@FK&L9_WtcAl)iHu1Z1}BAC@(!1sG7vp4l~8pIJm@!B{J7BB z;MPWeI(BlhzV9=U+2P?WuHxZ#bLFG%by7RZS~@_5krtl`U*z?H&2yhP60^uK=4QdD zOY^VUa$nK`^l3)y0_G3#$PG@r+WH8^il>5nhggQ6G%b0p2Dy~zJvUdD)9@W`Jm!JU zco}{I&*Ps)|sTqHiOQC!<`mm^lRsVhH3`v_6sB)X*U#Nw*epbvH>Z+)+)Zi*5o z236k>(p*S7XzAHET>A!rMwV+H>>`oa=;p`o0r7e|&=M}!p@ayDcoOo8f086hwgJ-N ztSVc*VwE2?-M|2^%}0dW;yT`v-74ZwFuM(|vX7$<9BxFw=KsFVje={`RYascu$rvr zkY@7=&8h43)Y0F^dqd>mgzHY=Z&qHkZDoU&w~<^}IVEbKF(dw#4JxR?qje@LzQ*bP zuy}HXPQYm^oSx3HQwnI_#83)RvGYr+jN+E1zSbAwCs`#@j?3}YR-tU}d0>&Y z{P7{Sg(?qP2VwCZoP1uSt!l3(Ld;_fepUU2ZT6d8+i7E|;kWN5+{e^9)KS-|rqOA4 z`r`{-&o|e;(X?tbni0-6U8t+mbvaQw)a4)VG;+e139J`+Obh~@q zii>KF7u$$j&0JT(U@|5(gVKsB+65ERzP78p`~u}{^_zht-V(vWD}x1H5=e~Sb^c5v z?B`Wu694``-LcuN_Y3yP3q{8TO0#LCEYHlf+8B#C+G03}J?eCQb2#kLpo-})u4st4 z-GY^k-#nJPctFOP92Fq8K?)x`?;Q8jhx0CfPO^IN#9;_E$$Z*p&qK^&i4^y(Rg;$ z%~Yr+Id}|4KO*UtGp%T;BJ1<;)*qFelhugHoYGKu{(;@JTk>#XH7NjCU;UQ!0 zLyn6cWo=_di*c@?3So{uN`4*~2qASy?W?x!zt4<7A_#^FCTH&E zm8_v2IzDv|j+lW!&WPZo=@SHo_2nA{;$`SmO={h$RnwU7DliqHP)^)l2}DtRVd z%ttEXk{Y~@qaQ3CU!1HDFJETA@{!_Ro-SJMBi#N4E~9x*jfy$2+|1Z z$LM;b;N*&2Jzh=b+_gW|Q|J@tmh_tqc)dwWY_-`KFxMw#G8k;_ObemjoN6O@*yPkT?(Tnmr)j`tV==T;Re^wCC{zq9r}_nECnd(@`5dSgFQpN} zVKgsFK-d*E0N2j?@Kn;=efnv-%}~eHc2rQx(G~Co7z*8+I%NV@K(g7 z3YK7_qP{oTsE=TaTmFQ>yWyz6LyLv#6DUE@q>LgF-8`3*<%~6lCE&f;kVj93Z;xT~ z8r?)r5!>aA)2|vr`5udUV6h@YntvYAwk12w*X^lE3YBX@Geh>~%c&>-QAz%^vwgYw zXk0%pO2FhKV}aYUF=LgaeaArhS6Un9IfePdsFpqXU=f@D+`nDPuyT0ZVqR08{Id|a zp0MEP07NxjGLRgU-CqGgrYy$dwPZ3kmBBJWY6^0!09hOZyC?CR>4`5yn>3ZhAtiO9 zw}QPAl~u(oQc6*g3n@3^u+{H0#9-}CS>Ujser>~$JsB0tjkAM-`fbCeb{<94l++J1 z=ea#KdWH@M9OwQISDA^8k3uGj!8UG&e1xe4)Dgj^8Djd|k8*=2Ht&Aoc-Uoj;cfLh z9!Su_wiaFM}*1;#}(tOe)7yfpf^=cQiWL=!4-XshO zkkwfL&tlDf`B)q!z$f`-EmfL(%GqA;?-DQWJL(dGZLrO;?{w_3vt9&~_);oFc(-Qz z?m-{Je8I`e>Y57u6aLnF)IiElBf-u0qO;lTtVbc(#L_4TS-Ej?zvL9K={N1gTAIqYmqOyYi zZhscJ>ENqO^yhQkKM!*jd)J>zYoqZws5xaQS6Gv3%QN!pKLHue^hkKQ@`@O>Y|~bC zIW#r5l;9R;D)|>6$iOY7E7f>(fG6*4-oE#(JY14-s4wni;fnkZ$fd$YTptN|G%2Lv zXuqILGMl!QcZ-x>O`E2_jk za(OoF74D?L>3o#LX$SEVqVVRt`B?Iftc}jGxTmM6$t>GA2kC(QjXq)&WT1K7Yd>qB zT8nmoo?e`ksMH=`yD!)RpJU_i8x0a)cB6`D#VQ4n zcZt$wv045V)^B^9*e7OujW1wRzI^dy|3IY_-%}kPHhG~Z^{pOM{g$YYb)`T>PyZ&b zGI{*w=zek`?W%t%UJywjr7!tdO-8JmD<-ur?Dgwtv@F7>B=n_Q;pp%ZJm-zb#=Azf z$nx~|3#Y7k+m|_kdB!N>E*Ud{K201%u^$9z&Nu!zABFm%8e1}EGWLWBd$29;y+>AF zFDi9GyGj%ijT|D?f+mhZORWAq23w;Q5rW(uM`gBGwpZkwMPgo6a4ut(Y+39*>HziS zY0)s_dXyvI}&}xOZ8RJq7>1@~(n4qnzKHM^G`0hoh_qnam=YhMA0gk2fr8EPH$V zZE8-%YIf0)QN}*2--4V`2?`mEpzk{49{ObZr99wDAjC?0=BrKsYRZvR_v^rghbZ-r zg{gbUOw-ei>)ywN>5T_pWlCPhm!alJ9Yl)APLY{0TNKUxLwOSF$gfq2=b`|J7gc%yGTS`Vw-_AXBkrAel>i?&UyiMS}4vprF=aGB>MPk)Z z(m}~3)m&Wc-jEYH311bTb*FvaFvw|K-gcMlH^|Dz0I&+2F;1$gWU$bxTVLf{A|K7J z8Um9(pj2&Ksa`cWKiHWqUV@)Xb*>f&s_!27UvFrAY-b*^0VG@)vPX)p)BiC(Z7%Va{AlR1^v`ld9|Ef6_sjE_4vevFN_gSm@sz)+ zMgzEcRmwYhgQ5+oePab}KFENx75k+O+k9D2j2m)m$K5W?ceY`z-ASqxbeHBQy31t3 z9I_aWVg$@qn9)#~ahXb8f_}TTVwN$^xEtD-OS;0wLYeyRfBXyVchrFo|Jl#|=l4Jk zoc}y4VgC;_#td#o|IhKjpCE+)Z*QUPhlP1y{l`?UmKQo3!6!d@)9WS#%Krg-GR-dL z|D1^uwFtN_W4VjFu9pt*@2?2yu)@^wY@f3wGUqLqn%jcKX~Tj;C(wm$T-g8l-#(v;=TCfU7ee4}z^n!4Pt*2iHB{`YfW;8+ zekNq1egEaQsf!2b!p62I{N{g-V^&&p*EaMre{!WaWNNccMs+3}6N>SF+|G#NT`u_t zj4tc>1&FNC8^}Jt04V~lCx89N`-Lut=f8uZ^D2`E*W6e3ibbESFQccZRKJ`Z z7Uvi`rh82YU5!})o8j`DY_w+phcT;y6$LTfy7o{7syVx6Rgh?Yrcoi?ZAG$G@v#UD zXuG#FBPyqx41y5oJ_$5?*tP5?Ylu3+LZs<3!bfTx^hg@pqejmzbvN^R^eETrJ!E6^ z8u0O1{d;CQ3WBHg8W7;CX>VDct7;S{n%PY;T3>y>1v=aU86->|f0cL~?VeBnfSyZf z0{@~$6=d{VtlXE(Q(E>x>aYs53{5Sa!hb*rLx89kJqO69q{I3rjl)8z-ZP4(f4%@s zn-bH%jN!F6q{om!#yD<@a@A~`hUqfK%3F>H0be-X4#f8{gDf!Hfnz z!5+*YcVDE6YWN1@aEO5ILY(J;68>{5Lx7hcy|#}Rp#9})oKI!_+7AqBx z>%c_@nmG9aNeb)64ZnvyX}(r9lHVJ^jW|w;8cWD|+}VRw(F^sS0oC#5c;egg8e0JV zjNI?<=1o!Be$d-TW5`|GhHud{v6KL)V-lDJ1MwJ7z`=zrz#ta%JvYDI6e)w+Q7Al( z|Ew#mXmn)a5bJuZxn%&h`E*kxwJY)VUKQ_4ilMi4hxNNzcS9g#uQCNxTfHdewsNp|on^KV<`2Dt)u*&}pntzzrov$DeLb5bJ%tD7%u0n%$(^Z6ui zfcy#H^gfIt;{e1LRe1KZJPbL@K;zp#qZ}LNwDeru6cGD>9sA-*-J}RGc=N8OTU&~Y zuP9KdPd$h}Ko7HKfWjYnq~r_qU!q8uMhcw5X>fTOpT=6|Z=S1kNZJc@HgRYfR&t2{ zb069R(#C)P`R$m<#2d?501eXWJtK*0;O4I%f!%Iw{RarUm+6>B@mSEM&tTjU|13?x zS0R>7X@=*pesCTn*LF-70GtHU6^z27oG7$I^;+)Gdz(O8wY&Me^P$KJRpk;^#Jv}< zz*=g=%Zcd(zEr?!83fWkZlMteiX_{c$b8LWkO5-L*$in5s;whiqX*rHk_bh_Ua@wC zew?&4Ldh`HJAl?ed{czPQPv#DO|-2M%#-SSyKGzi&UGM8SpbtPHh8T$Q(5fS>xM;j z^+8P$4#q`eSW?Ej==ly`o9uF-b-rQX$n$LyY(C*8X9YK>zq?HW!FHKPN?fn=?hp-Cn~_j zkvUS5(16g+Xv&`sX%7bsUa7&T~cN01vcoJdl(%O-VnoK`*I zWt37hyLPysYTF~JE>({{kt~g}`|`L~>_UJ|D&2U90=@r^l^{6sZRedn_pFS$8vrWl zIHlr4O0!E=2jII^76ytk9MaQ|g|2g78&JvL0AQuM7NDqp@eyd{mUj1r-A6XXrK-Ux zn{J?jp%qmNM-k!XeD6ZL<46;j)|-MG|CG~blK8Lla+KRWY7 zp!n#s|M65JTAATs7P_NLhryn1iURm=7fokcbb*NL^gaU};pU|ye%5SxW#e*L!>hJv zOmlAHi|1Z!!HV@nPgr9%C<)K!w@(K$89r;rh5)++l0i6MSz?W^-AaXPC4kleJQ(do zfX77vYQcFm(n>gec51dikgVW5|M#6#EgeYliRwrS)$-~VQIjk# zJNetWSNw18ng@laJz*#y?Jk}v-T_Y@gYw&QdnNpVB;i! z(u%E5^XIP{%C#Vfi42X_rs|Q^;C;~2z@j&d%p$b99d`dZB8)zuDtDkcodZdX4eD50 zGTpYaH>>m2bY-=)qTtO_A0h9GT^#|r7WR+Ss#EXESFgUll9Lqs2rNtnMoPBidkWoL zU6HSUFA>QR?*P5khZ+|zfOpIY9m=WL+-~^41YnA1_rY?r><+)P{!jV z9Qt`W>$W8$;AuC$?sJ|+{n3VLR5=DQ>fbk$)!X&e=q57%Wc!QEE@DXUzx^E1>Rt<$ zDkx9|bop_@^61Sn)D&J6@p0N@!N5alf0RF!F7H?!&$3jR_tSdn;OYWop+lz^Pvi`( zR>hkmta+;&67%VVhxotNwYxQf(dUJuJ?*{v3dIXY(yQ`*fO9+2MX$KD<~$M~$GPvGRx(*|EKR{Oe{ zu0y)vP{xGm)onVdN`SM#+TEtC-@!mVXFMX>J5e;n&wjZ4)ogcssRpx)0!G`PLLEWt z7a@EF=W|QUs6_!Sld0QZ(xuE$hdwyJRFSvR0yfnaV(^&)$og|x10f99EJP~RTR+~K zTKD|~5)nYRa4yu!d;3qC&&Fo$GZiPzoc!{f_(GGw{Xu?);OZv!F#oQbhfr7?iTDTr zqCQX1((n`Qoqos~#fy|Lmno$Qmc(bP(p*vqz8-b$)78iI>@4Pv4(GHE*cGuLhz)Hf ziT=WZtAo+w)xOK(m^Kh>tFDxdYKrB|e3U5#g0Z3afW*A;esHR@NxM_$?AAc0X~uGK zJZ#B$=rLQ_Bqie`RVCTgQ4M9kc=WC!ut|eu-mNpj9M~T+-R5-;zszCh$tC39++@B% zOrS3AA9YWYay>Px8~$#>fw~yKy~$>ZFFtq(O950u#N^XVc=?e0-BK|R*Kiw zyuH+GO3TqCXj!KBxtD?G(*bsjJwa?+G^b@L#g;hSbwdMH+mnAQNkOuM)O#ZYBm)H@ zB$g29#pe)8j)m}JM6PYe zJDGHw&3xi-`*jP+DTx7k)p44UOs1Ny^#uRC$Qmm1GL=?Bhacdm`NkQ+@%ETku_F86$uu8w8|T9L zF@3F~wLNE`@ zud&u%8g;VDT=&WX3qH2FtS(GHLtb=Ny+yRy!hHD$xlQq3Tp6;sJ>$EWe!4!*AI6{_ zC}?d3w1V8UO#a6CR0o7OfOcL#IW_FKc)@BK##bBNIm?jf5f||)io{xv`OiD`c`~dZ zG@an{niyX@PA%f>JZU;s{2$l!Zc1Q*?^FQ-3|LeJk#a9H2CB2=eIP5_*OOOfS&S^| zXta!(Hadz$4Zo+p7euqqZOO~w)Wfo$G0RsgkObA@ z?Y^9Ew$C_siWLr8k}}IT9Rem0cC_BY`8(3UW}~?YQqQXBR5m)qqLU;)TAV#@NQ;k| zu;T>!o}+t4`r#iCq+&SpGM4k;eiO=rJyJm&$GrX1Qqjlxf*0^w7y}qmFEq;SYXc{iYU5Yk{H2EX(h*vp?-=noQGHy0fJIoT=nd z=;SZB+5x`@{cROQ_a`91P8Hu5{+qv*PQk``*E0f>)O0{VCImILL}v|o-7_&lUwWJ_ zV0*_q>4?9a;(!%ZT{*Y5ZU1w86!^-@;fxDQpER$uvxy=oK{_j!zz)0eWRdb+LA{Gb z*qD5|@WJ;Wb>HZ!tSX}k)&)AtZ_84po#t;Qt$#XOiqeON2+Z#=otyZv4I246>e{U? zNRIoMOmNY+HZ9mnP2h2ota3@~Oav-z-zi<`o=L$rMC{k7#Fa^2^IrI~JQvJRn+vZE zX;S+{hYv^WAqI6{oywvfNLo5rrNx~H&L1gvq|t#^qt+2Y;s<~Am1v~{Oa}DyFmT1w zE{mQExhc`iDITHW&=TuNL=I)*YD0}}E~#-?l%v5r#bsegfV`ean|NbV)|iq%CeOLS zGOpW_J4Wy(fnF)N@wrouvQna863pgrY8YulWtxj+%_pKEV8rqcxxb}U1KpQsTUR%W z)NkNl!n1bch$UhKbCW>^dd7+X^3P^Y!!bDfb<6#+n@=)1HqPHpcp9oDi)xnUq1y!s z%JM68b4RI7TJ|P$>X%kn!F)2B=0(npiZ#%?{fAtXhV%ddy z*nB}+$eX+sky})rwU0d9@iuW)tG$wGqW~u=N#=E);;F$>bXV|iXD;%>2KMc+xZxUA zY5(OuO|+i#)LFyNrmHuq3+PY`rkiICI4JFI`XN&EH z^zis5_;g$wG_8j61^CZxJXQu)>*sPiw*)(A2r1^YX(#t3)1Gvv9+@5yh0X=D#MrUtoC4Yj_1c3PhLt&) z0jK=op_`oP+V*a8RFG4y=ZLzE5n;0*{wtgvN186wV!QeN&bLt5Q_Td2`ePofqF6S| zqDRvtMG;fpTvWQH6MPd_CXt{nUbWCOd}JaQ9ZyJ)FpOfc4b$!%#_uFRv?O(|V05pI zxeBCGmeSCAf9}c+6mhAJs>&jm3yL}n_hZt$S>#P5`EmVMU8sGhcuxsOf)+f_jkr(IPTf%MhB_rUnrt- z{ni{3p4)}TV_FQ~%}U@&Lv)!eEfh%`9Z}0el4sPoW6HOn4zS#9(ll968bP9r?qHCryo$t0WDU_Mv%pVJL>oU z6&2eiy@pmNFx9AfdgINexhipcj`#L|J}j*fjGh^joNLZDiT>;!^nnL(8A(_yL@Mum zToXIc1cga+I%P;~5CKI)F06gTtkKPFx;EHr{pGCy5(2)~}Q z-Su<&uFgE&!kKP@Un4b;zILSl#pb;0=b4Vn(b|sT`6YMqlm_1Td0^K8G+Mp(i)Y!u zuV5f4k%1~cB2s<4BNFfFc$yn*1c5x&%p!3CWdkuopPh6<&Qsvg=7T&Hl;%o`@J-k2EXKp(pOy)<= z&fwngu$F>iK#MG{#;@lpU|-k|U?&ocPCa}x4ixUM^K2tFna2_L15(Aq>b|Gsvc}t& zBqBb%6AAt%iC4EcklDU}9)=Yxu0hV9(OuR0pt30-Rl2#|7|?OcPiMPVW8>MYe-j1q z2Y4=zY~F}9d)FRXzaZAr)Dh{AcEAy$Qx9}EjhfPTxjWJL{c!SpX9uQSE7H9dP&3z! z2=8+g1unPT_k>X}$ALw|$@Om-RLv#K;!o>DJO;${(1r+{9-y0v&>TUf9X!-o?FUzO z)Zivp}wH#{-^dEd7Lw-AIym%h=1W=VoFDjn9l zDwUK)VN+#~k)JAz2IIg=^g4P7+{lQC+~(Ta4z{I?gk(;)beC)(wEaM!2Ujp+^Ek}# zO}bij`aSk@l}$+nomb_+k~C=**Mqjmd4Xh=OwO36jVy;!*aY_{Mz|o^035?@!2Cn} zwG!j;;Ma{2ar>Z5L&h2b1XD=u&Jj8qAX8nhd@Byt8vXzsvE)kube7;aKGnh+jTdc+fd- zhaMJ1S5acG>yP23?L5b)92_LzxayPmal`OoS-s?gUjmQ5e35t6uc5*wa}ClQslJqycvtE44wPhwBcf|+gp|Zp zqgF=U5uFV^dxZAsHc($QV<_fmFKjcepK~!q1tC5mosi>GA@D!4<jq6->^(5(KK))D_CH%nIB{ju$3?&`LUcYi=ARi@nr<(U_%=^M()#U;7&6{o`U?mj&`qP1 zDI?#3Z+|GLuRqx57|Rvv$-Gq9w^x3-#d9f1J(j`_RT(1};@TKVDD8au<)W+Qip~m1 zp%jk$yyb7o;-rI$UvK-4EA(usZeyT5i{mk0o^@Dgdm0R+w^a(XV(HeS za6A)M8*;hANEzy`Ar0=xU))$rKcuPBpbd1oCds+kM`l%2R*X=cYpQ*R=64>e%dgla z^9K4BZ3^;o?O=d8V135ccD>VeN#Pay@S&E-E5;p+OBqep# zP#wW2$5g==N#%w7-BOtvpXuxh4v3^Z^3K0uWS%qY%G{aBjp498u4_2JfZuWkR(Z8+kv=!}T`NNn4zs~5Hl92JBvryGa5-c+@ z)h*Sdua=@P{EMx*Y7MXGz}yqCiL_t8J+1Nnl?z$bn@$29@BYCj>lqb zSGO#(#QnlG@!&(-3a=(}^vw-sV zgzX=yTVlv}*V=csh2H3+8UU@f(7|b8?)ku|GbWQ*x3RqzN3g-EITkNw0`bo!ZYGh8 zhF-N>mh;v%UN$DW8gx$LW;em$x&)E`ti%f9W64fgVD_+z?2Fwaj@jrRGocvw zy_(X$GKw>N@NdC9i`^rky;@| zLyILl}=^|=PY3m2Qwj48U4d@<$X;eCV zaN4ld<=qG0)e8})%Ionw!45;K4xF)nJ1z*Vyq3d2 zx|kr+DAK+HJa6atOwP{9Yc*b3HZk0B)*YM#yfM=DLNNT1fD_ASL&^TD@1U@U^Eemv zM2>%|Nb)yZBS(r8^Tf;>9YV}q7|HYw%VEjQsuX_jng0k#%5v@Ap(i0#g+%2l?9-@a zYNL|1!eVErR@yuE8Asqe22ya1RWukQN9pdn&i0`4SF6Q(5iZi^s0%oZf;~9oT~;pB zWfmVVE2+{17pkbS9SH_&t3t3TNg>*7H==WVWQjn`yujo3doNuhPXm8Ph<3K{uiNLi zx2o|1o%!*W)Lj2|%z3B$CcwwO=0TIYL#b;=BDnb2!7)s78Bx2PlXnv0y8>9P_vFsf z&W7pORULt=>_}tV)^!e~`xVNK}5}i@cNm9o8w4^ZMQGQX&d0*(lvit+ra6uXwEp1k&V6wh z$Z1ab^B*)qC$Clf#ouMv8H##v_hZ*Y`)?x*&iz)CEf(oKQE|0+F*KpF;=oe@v^@H2 z!CJoBzlZl*A_D=I&b%+K#BkfSDzERHdm2suB7~U&JV<4B6?UgvE_&k~CI$58L$VJ6 z#N<_t>&o96(QbQFjYsFn^f+d*_qGTqOGKk4@ zBD_1vy0&@S$OYFS&F##Eb)hCl6gjvVre>0qKR4wTN7nOcoj%DY3cZiB$EEdO$#K^Y zB5Md??p!@YTYSDgm0^*Z{CfMGTrjozeHJR6-Bp*~)I~L23eCv${7HazG=o!7l&orU zQxC7Kbs={yy0?&*kg}zo=PY(wvM{!i+uU3egtYxFjH9xH2#+?Sl23u{Uh3EJ?y$`e z_A?pnLBaeBPI*EZx$t_yw>DIfxqD$l-x!{+G7PWYn;nfVNvQ#YzzeA8)`@#)gh^dkx z?7ia88@}{Puk)vnuqVv$RGxLwVRyHh2HP7)ZzJACYwSa{g=LjGC6S!abWd8y6=jpV-d|CIb-|` z#uNO+7s+MdPd~Wwx*+x-SI^!UV?_MiU-6#Btiy(wF?GjMLtC^WAn&7(7p}v(M_gDr zlUO{EQ#R!;no+mMfko=0il?)`+!#kKB9&_mqBxzQD{Of=OFdn^RMTE^CdiAcaKduP`l?Yabyxw)X zb9SzGB8rcIHYr(ZU^*CCKkJ=SzW3Nrj602u?-RjS-BLNJ90ipQtNbWHJ}yHhNiZ14 z89zoBZY+k%p?-luEh{gUsF+>o)K2uBW1E%VfTfKAv|g=wxD=LVnx7GA{`9#D|^XLA5Qtn(z@@EcMDI)Im>? zubXKX#+MlIXjnx<-#iU>I5^Qa^QZG)0ZKa2D6fS2hED-0+}Y^cK?%ZSJ+Lgrfy4CJ zw0%Zsyw#T#!Z*Ugz6pX=Si01B=u&U~CvzQ)M~oj$Bn zEqLLrF{MGz_ZGv1$J1)uK8rCswXS$ymipQ*eOuV7*F;Exw}RtT{sx8uai!4Au5Kc}hT{@CEy=Qev5hc~+Ux;2e!P`~!7XOn>|^+!6VLvYD2@|T(pvU3u* zXdgn#n6A3i0jrg$6$24eUQE6xG zku4U#{qcc7KoB`u4k%j_V_(FH?mX5)6dqJPjcAl@ zX7A|gys5ADAD=0@Hf)nCvQ77%&gsNFPRH~4$5g3^WLgNTaV7(+vZ1&jy*POZN=A9S&j&sYfcF|)} zV&>YL*s1D04Vj0Sy{b<%b+b@c=e^U0Qn*fRieXIWZx`oa)#VI@2l{Wb9}~~7(TXfr z>F@LUOm~Bphf=sF?LJaT`|}&i@aX-k#TcL&TIJ@#ZiMRBasPJ-R)(HCM4S`ga~x~> zMoaj#+|8+BxSa2tWt=xC%q1uCkA*!QRGwUDe2;YqlXXadW11>2a4`PaXoiTHv>(JH z9Jdq+Q=KvLM!}#ZCTmtU)G|*QT)0_JT9{Pw2C!qb8*N@I7?c;OcYegWpDrg(nA)zo z_y!~;zZY~@sZ2I}E#8f8BS1@pW?4QC}R#yq~s3Zy!1g4aGN^%K-dXK4XyKyjl;mUaBW>TPX z5om2Ail0OSJrUO1NwiBLy%v~6>*eQuMecFAOsaxBEe)|WdG^l7dHxw|k0!5@o}zo+ zMe>Ouwn_EM;Li_Pef}?GS5l6D^Y|J+(yD(H zrAV^M>4Ds?xFx}1rDprBBy;7JUZ zC6=vb<0VZ+5nmwL#4%N`QYB2VhC7+O8WI*%OZXxx z>ux1ySmD8|FhgW1rwzYpG*w32ku@Oe;Jc_X7pl_LY&xF$4P-8dkEMy~Abs+@*@b>_ zDY6WWmZh_Iacl+V(0VmPj6&ae7Tp_;c_g8#h^uc|f{D97Q`fy1z%bJ|6 zMhyO-tEIP9K$^PMe4XDi;j>g+^~MaHS5%_<83euO;pnof8Q{f(P?kmIUrUQcEB zK{_g0hbE=g_f|`xlO#H0Q|yuaMT_kt*r>d~#TbhHX2<_^CItjyeM?v2rAj(`7;}X0 z<^^`Nf-m45PhmGXa4@N1X+QuHz;+@z<&bPr;(GbwpE#VXzP~|R%W|wQ%ktO!O9yMu zvl$p*Y}J0zvR|45#oKTCT6*ey$Y4D#@l*KyFXK3U!jHbJXqNIE6|$_iM&8$O=+Ab@UmA}Yd?zw4Cw&tr4^xpi=;?egYL3v6k6W7!Bl)E+m`xj(rydaa7QAbY+jJktstY|a0K)tZMZ?l;L1^E}%Hi_Qw#bWCF`Z<$F1(xHlXPwQ~uUzKyx9F7!B;(9+3!9 z$!`}lH@z*5V{BR@#MFgMF<4s)JMrS1dq8y76)fbhgnjeKMdNaXCdh~iRwk@LnL; zv97~}P30kJuOn^M_BC5m=;aTTeIos(#Feir0bh4mJx7B2fsS4OQc{ceErcq(kRqO+ zCV!K1I~6{npQK&0iYGgdO(CBb4XV)V{*%rspnl*I>FWKkEijZzVV8#*NMZ(d6^xe6 zzKijw*^*_BR3zTvrdggP$jD?m+6eTfah`SZ zp1V=SwA7O@GvNZErhYeez@hBdpo;(RhSD;dV5*+5lxb0g>0`dT`EA@jv@1P8Tzrhj z>SZ`tE%QMklirVMp!dbA+;H?T{VfG5l1)8BnlJ-c-uIy0q-2%&=DE%^pS3oUV=MH` zm)JJ5dwXy93RhZ((EAWP9^HL=R9J{??MYaRORgCd$ioUpDolio;5%$CJ`@^>;?F|YH%o6LSvA)m(tuNAx( z$rD9qN+{LNPH^)^n?a_+wYdGfYLNZ)nUA$_W(xRrpNAC?P64qhx#6A-I~o}C^rpBU z#ih!R-63atmL=F1PRX@4TGterPNcjVoPZs=9kN8kaXp8tiV3VoZNuzcmg*)+h!T$= zK7C3_A!faeI2^{&%dMa|&t0aC+agl!f zC-eAGGT^@IjM@cwrH|uX>S^I^v5Ch2QFy4m7XLky{A~Kc50_f?JnAArs7>30;z>%$Md&z2JenFY%7{p=q+)zw6{({_~K-)cU1a$5!6jJ_{*`wj4cMXUTo z0KXEPy_qNm6}evaI8>Fm0m%K?0KlVHRe%Gqo@Bs#5cqlZ-@mT9k8&Dr$fi#ehOF`v^v?kkW!H8-J;MQbgzdQZ_ZFZR z?*WAEP!+h(mcTLrXijcH9Lc}Gb^ssbq5kw!KA^3r6BwZBJa6)GWk;{)R;PI2Ya(2wOkpxqzI4b0|{^h@`r zqmLj7m`=cwRweYDa07tH>^*86OxWhc65=9atJrEg8`=RhXsz=Ys0#nTK5jG`i8GD^ zJm&t+vsZuLdH_BoK_qN4)AJy)9!{9DF2;)NY7+pag3e_tbG8?+@(&_`qa8rp;<9=( z@v*cgY<9ylhD{@xZ&JKb1Za}YLCWOzKRN_y1n(39kk5|VN0^q9fid~-yB@%lS;|CX zvia-h+N1YFOMN|H#NkH*W7(rzJ)NukQ66r9)S$$TTcHoYv103(bCHtEGXA$-rl>kQ zz!!f!2TTQZ2Ec-xu>r2i|IlI7(I@atz}7Hle7A&fMk2FJ$@G4o{{;d%wITznbe(MA zC|MkIs88c5xcDK}dRzc;q=KRS%nN8D@A3WijQR(Xt#Wd94WlT&s1K^JblNxowl-6H z?YenGhZ3jO_Hn0dqz7*tITzfwwhwr)r#-QdP>SD7Lzqt>l0%qi_Q)TUdi+TJVvy29 zrkg(jK+`D^4$TN@cq0Wy<02RUPX9Jr2G*1$nm zB^e~gUf)}eY^2Y8{0B@Jt`0i;xDjF#jkU-ODTZoWgLyd@yC+)9TLE8go7OT zk;6JQbG-5Vu8xrc#ZJUL6`F(F<>C8(qH`KRhRAOmNl9HSBLU|N1EiE_*ECZEDTxu( zNZFo1o+%mlf$zL7jb1^@jOav=cR#iKc2E9IHoO-(`TVv~X}(8(Cd z7L7Z5H+o7Uk>-+JZ-7(v3_7w=6v2V@6Ps^$O}r@%fdbI=IMs=36fxX9u{zcnz{&sb zo4qkC`ql@aRu}*4F8vpDzxiqnNq^Zs0&sH%*DG?g%g4LZf_QHYdN@1TIOGFF(Fs3} z6i{lihHAw3m)+tS;1%V;aKaP5S#a(=gDZN@vZfYElOa2U#*rUF6fv^NRxrc-=4NdKIM_#7qvVf7q2{x-FKRAI6lq1kwj3{FnkHH~>TT!Qa@mfe2&BV_Pz$fiP| z@&>R~X0;sG4^m&2TFFaZ0fOV;`?J9^ZHWZAVvEF=(pErlt4Q}3IjnlQUHGuNak@Te zX(nH{t4J|f&8`Z-zK=V1-vA4zw0w|msL|GeP!9<~&)flt-rV-x>KfUtLmd4g zoY6z(bPxrq)TjS8vC(Lkz(XU&AtWC>j!aZA_#m11TPs#r+RXWVO_3_9Mh=btbZUN@ zX1zV&dMzksLIbOt$S``vYuo|dSHRP#IV)a^Ar;AEQs$Aq5gWpJ?0DLBoqwGTZ24{# zY-UDPP8IImyGU*sxP2B*T$^vR^U({|Tr_q;7VNM1GzwsgVQ%jV)b%Na<&~7<->hAR zGm96lruDZuEnQ|(0tUN%UsXk3RP&oIr;S*BqW0H6j~whb8gtLH>665Ho#?T`fVT+H zLJz#Fton^s|77fVw|1EDsxScl@$c}RSo$gab@@%v7T}AG&aZE zno^);(VkFm8g0%c-GnAVV@*S@k+&T_>4gux&RBt*r&OKU5*;NQMO+}z_R9_O+WZ-K zy)xq=;-s`h;_eyylx;gUWGrWLFD^!b2Q?&h;5%ao*hj_N1OI!e3%82yiFh z4Nu*Ct+|)Ga}AOSZ#xM%*!29wKG=PX+}m=HgfH9d%}q5D!mgN-lj9@2%RtbcQ#9={ z|Cj)>Chdsk7`BZy|G>y&u(*adzZcidWam>IC6{hfNuhagN=(^oFGF}!gu7@Tt1J=WHW}gcN_;zfQ?#jj4wBae=1>l zyGD*L23j7nmrNv!jm>Xh7T^Mu5yp`(>2?J7FE4Q-2{zV&bT>J&+k9ksNU_eKW9Z*u zv|dox{OR<)bR|FU@Vy4|K*z9m5XBc2eMxrJXpwJzXwc zycPenp+UGBDfF4eHBuf7=yhL?zdMTaKFfUq25A_}6(HYm)|Yd71&G;VRbzbf@B&fJ z9)!#}iDgrw*;4R#Wd=^byO9D!;h|<9iX--MgofwgSQc^?PV6!ITWS`KPipTUPCoV! zFDa@&Ls59?om{z`O-Uye9z7HR%P=;Qq2zhH_TRyq1U1W=IN!y{F>8vwVUAIgYzXTY z6pjxfU4VZhp{jb`zVp&>2!WpjGy3{~c>1nsyJ47j@0AdKy6^W%l?RBIxg4FP6kV5ZT z41W#)b{o7gvd-@ZG{yRgmS?}VFDYvqqn~Z|%7i^}ir^@Rktxtye3+B_BfM@!n!wt1 z{H{l<#Vr3hnUB5dK_pgVRYSL*8St_MuAr?AA zKn@3ZE(Ltu^aa2Iu=xs(s(~8@nm^I%c5zN6MbJ5%lE!$kcTx@PIi3yZBxt{bb;ZK(Ush zH}~FDYw?S~93B)yr5re=l?f@ni9GK&ee1vPPyG5?Y4k zsTU>cXUW%dmq?BKy^sptl&FOgQ4amsUt^c9%Scf)sk(;38D?55TKBe(A?rYcia) zh&^-ZDZ^feY4^`@fTPlFh^>6O{1uc1?ynnz7?~4(IVk z3DtOC0s1Qg!4MW6mjK}Q*~$|#395YY4!YUSEZb%{Xn1j1)?uMWn&)LT)W+})RSm;FR9ss6BcQe1p!1?5ywCAsE zm&o3yLYqfSH)7P4XyGqHqXK(ld|oq*S-4#-rsd)`H@5wa3;+8``fep!*>%EGeET56 z3cV}EurLs{5v%b5ljmdy;Oa=noT};R|MnSH)PeYlMGqxw$qu-fC zk)jTh!(e)<@ftLaHnpK71BHYE|(B^qV z)`5wCU1iYB9DYhcxYWup&mJ<_u$OT6lOBkKwQVGa48uY4Uo77vc_x3;`B2&YLq<80 zFyQdUU4+F31(T_;oBkZ7m&7T4e?!gc%7}##ml^A6`;9i(Fsoqb%rD|MrtVLOC`8q% z`Bhjfo1Kqn0n+EMaA#(&y~a`V?W3~!#me*`Q;aA7%X$09A1@{f#;f z%me`wME0zL-cy2aZ$!LU@?a`B;t^u?ohUM%ZGLI8eKdnS(D_dOe4n=x5Ebe^MSldG(pf8Yxo~ol zmCCBfl2QdcM?~`g)Pgg1b<#&C+65(2uqIJi0-p;tvJUFn-D7GLQ>Wwl(`CvqGUFua zN?pFi6qjSlf4z~CPblB#B1l-RD}MPtJ17ZiYmw7NVE4zktF7TsN&<)H5o!LBDC^5b z*bgPUt;8C~m|a%B`qOc+Mv-(C<+0!uDB5~?bK>Wj5C*tQY4*mgGEFide-};np6B%4 z>vO$&ipXRq7W_ejU#Ccb+_cin2hwdnBNT;3@M%&LpiU$g7rw|SVMY`zYT(qi9_f+i zwe>oVvjQ>gRrLz>^Xco3%eSg(B;MgTr((szi%~6x+;>jB?`Q)FH#nT#;(SGBfSY>w znl(0pC~Vn8ZI}DyOpEM*oUG_~WXWR(`29f#Xu6QU^9kQwK&4u%(tnZS5r{Ee*xBBL zYc4DUms>902=Uai?QngG7P3@6kj_KnXRp#!a%gXGdU<2G?1eCy`B&2Yq1ZZh2pI{U z!&v5RkJb&Vf<44_16;ey10BLXrFR5p##%qKkiTu_(Mq!tL!vAImoiCjS)HqH;ym{8S%8 zbq{i=PL4f`xN-4r%weQ&xcU>DtN4mt>$3PHJ)$x!wkqDU3K@`%10tTxc&pRqXvc?@ zqVK;`{C*HejBdmZ0h!sb-JfR_6fUz#oVfn*Zrjhlmf=O$eQwhxjo7v+c(}e-%W3d8apbjxd(gpq&2Z1{^>Tuy zT+obr+`bJ@*S5?5w`hmbQzCPnrXwBZ0t1mb1)=di8x{<*4;9h>*WK?Y=EB}Q&S>r2 z#-v&)u0c*jP=$Q5$w9O@JH4iQ*FW?TJ5<2-D0M*O=V!>!*B?qcY#-njMqw|S8R!hB zJ$a{}HEWLp%N}_eOtWWhwm=#x||IAp4Q z)rh2GGuQNr3ZsTiPz&!(1W;#?sgpDtXO1M(B=-2Q#WE5-N^E<)lE-#Vdv74;b8*|C zoN?$nr=zD@6O%ySY~ZD2UdIn4Yf-yy7AJ!;?4chfn?wjiN`%chlvd+9N$72KonENs z{I@6R|JB??&$oqVLaqg8H@SacknINwz&2x)@a-iVIX3=8a0SrHwxWkg4XyYg2d0es zuWrHbdahUfg$`?bf(Uq+b6utVj-w^{Q|?oV{-r2MQgkj? zAs9JPM2Ks}70~Z~B3`sCzjd~Iv+!{(aOR`KH*?wK9tq#-K3 z?wrkf$=OCDq5k8yc6uDDeVf$pGyxs}-C4PdR$uF-8Sjxnr3n_Abh{xh%kwtO&c#@a z=YO-WACy@Rnlhj!{C6b*yZ1__Pw;9?+K!I|k$4-kh>C(H`>NObhq>8UJ)k_BM=h;g zPKK_C%l$~)OHD`-R7yj2Gww-MQqJKHQa}Ncp@w|DS)<+wgJ#zc$Zd6UT)OFl|IL1F z7WlR=EoL*d)6xVP_?AUPEc7qCN5bcqh>KFl*!p(o#P{!IQg!-mLLU6-i`-C~WR@0O zqrw1|kl!ph)8!$a90pMy^+1UQl`pbR2)DGSz3^$0uuK?p+B$2X)91QY?skPluM{5@ zRXHc$QIN5%b+S47)j^BQ2X3O-urzr5gT2vv>3#*vgibi_JrjO2Z+dd?@Nf1v9(ZmM zxlHd;dpb@?hN7=KMo%?sw|x}p@f zCm@Y=lfZ?EzHsio*yhvaWAOfhi;S?B`T$6!AUz$NliR8Zrs?uk|+=+m}7k%bLoFl;J6@SqvSZ>|_Z?lH-eo)@6PM^rFHOXtVBeRCx+Y zM*N;c=(xLn3j#`D(b-=7L1l8ZXJ^jw_;v#QByO_cF#;&z_uBD#UQBqph+^R#ch(I1 zzdGPlJ)dJuknelmcq5Fi>n)tZAC}U^$IfeJtgM7<9Rlp6bt-OsJ5!nY42A=ZeKSQb z3nOtwFZW}a4U6wqkfTi)^nJ#>h>lvb>TaJl(>keAp+$U7Nh~*rI(1TJO_5`H&QnHc zjm{=*O@#i+!Mwtf)jYSNwIToW<_FrX2BWWjpoms0RA@MNdgQ_!BuIodA_=|RVVdN7 z`35FL1oGs;&~lQO&Inooy)<>F|00_1%HxKuTtc-gneZA)Gp+|p-VtuZib?_LBm{)g z{Cy=)=>~=cMYX?hrS2fTR6-O8BF}kW=!`iN##M94hVl;4(dBY??|;6pJu}Sr`1gwc zz#<_(0-W@#V@}j7u&g}#wBm3Cou2w!6{^oUP-H|EHYtO1{81F zb*ZGiV8mD7l~I$QBp#kvr8A4XFgl(E<0>X z2l{0tdeED_QYzo*SIJVn6mg ze=vlnm*?xeJK1^$#m}Z1bG?m#9V^DA13CwwXm4tNco?G%deJ%8yeX?%j9%dN0QkuL zur|L^?ljNPA@kEOrA2>u6x=O`h+%WelDXRV`OO*HVRRr&{pZxVM5kG0v-lUyTyLvR zI6oFxQl9^5Z4U9Cv5z~cNIExVrF7!{;T6w-ht4fvJ1uzYb$Ljjha5Yk8N~wW>vQ>C zo&((@vg_ghq<>KEJBy*r5mJJ?6b|RMwcjM$ISG3dz1|kCGxJ*fM6Xo}&)AFJ?=l$& zqf4<4!wU?{{g*eXD@N^1^bBU0pkrK}Gmz8+Mz4+Ym8GAE&4?NK;}5kFi2o;RH^%Ol zSgNYln9CSVS+gm-TwPnt*fo?Sc(Js4a3U9Ax216+wyEv^$Skq1X!q1O?^tOT!s9q9 zAg6>Qvh|ta7um&+3$U`u3c1Gzusq8h17-#gyDWZc#Q}9rqRhpgF70K$K4=C?LSvn~ zdfxG>$+dYvy#q>EV3~%{M-Dwphf8(oH)ou#((6p=mv;!eSNa4?P72Jg&-h6qltpW^ zf01cfWj7eYc`+B6pbf4QmJ_-3_E6KN6P!}zBCnESwuvPHa@5{G|CJkic_aRQ)-7!b zGp|~V1ht$`zIj7lxt&c%j?d09xfN2+R;mWk&Bp>JOg zw?Y9oj_{8F{taI`Vk?t^HDIjZhg4PFas&MR`uB|)!95FjtRh~gfNwK^fD<$?9Gs~4 z6J=|f(Cdk(3I;if^5{cD%K?VEkF;N~>1b8Ihy7qQ1WT$UB;jW05u%gi& zMDtp3Y(7BDjGWuImp#avb;C13%?8j8AdTqz-_kH+eCzjxUjpuy_*+4!3P0TfY(ut5 zM`iUvs<`Oz!SLdzZte3Hp!&jOnNvRU_ap0)y_*%jO@R(3x0j4&cd^}kIr$&IalugU zXsZ3-y0*4WosS!+v@&=$VDgyr28tr*Jd;yIIJ|}TfHn4mW6al#6$M$YQ5K-#*nPJ0 zh^Z1y{laT4@H)?*Q^p715n^4$NsP>Wzy7cOri;?F9E?nvKdNTwtAjt3gNdYM&{c7G zVY7C)oOJ0(Bd{)>1o>J~zPyq_fxYQGD*a-pBVp z`2OaQgRJ|0uIs$chppXJHv5_1dEyfm>2p)0?Of53>S!GI16O>MH=ELq%>&|JQwHY# zX3aU`^--nzOtLd^4VX6^DjU9P8U2$AYdW=B#yol4GDUsSr`i1{ZN{eGvN-+>yaVaq z_4hIv@K6~tOk$9Fn*Rl;P`Lc(Nw!Gq{xMn0QGwUh2vnavl1BLw`wQP_tEonGV1}cYEwSBMQ>H+HIwN#N2H1NIV6(vAyK<8uLBvo$AzC84+BC-Fy2_Jdr zErrXf>G!%%`gWaG;xLa1>g^CDPGWj*C%Noi0!L1vVIRj8;4v0hx>j_t$&&hNdj0Yt zDP&0Z58$OQp9OM$&_CZij43sZFP2Ck=9lH(q1&pNOD_PGj%)dQP6Cw%VJNoVK*|5o z`$jy3?YmeqS_v3kxGW7C-kZ<5 zEM)2N8Qb5qVp9=6;PX(&0<|EK+cKQZw$wV zC04Cp4NR(Re`SCqPbQX{0?F(E34oQARudt$>9qz*#ze@gqeub1;o<*kxO`3Z-=F3~ zUg_L@)lfhmG}+>4xO-d-Q|}99Imvx@zUY;6&eX~E=*KUFC1s2nrJM0YY1PBqkeQ66j)B`c=IOq% zD84a`zn>>pce+T0p58XxBEPug8W9A)mJ|<070%c%Py;u`|NDZXXol|et7qTkv@-+n zy5_Rt7xH%~$5r8H}=+at&=- zhi0vqdzKGWyiCD-)bq1i!gSt_AF7pO)zD3wY8>iC?xD2bJ~gMo0Ox1cG|hl!C)c~V zX&8q7r&HQZipvs({K@F#x99Sv{_RtyRI@;BQCLeQ8^bevDSn#VU{>CimW$ zNGhG0(4|0QtN}$X@weVjD&IYI#ga7bRZsdj#`0@B41~B{HGZG(&>42yvF%KsWyz3G zhj6$7HQg8h_`xDKmciy$_E@3o3moysu(4%{DB*3hw=s9U^1@qQDW6K%tUtXMzx6Vu z5nJB<^x|!CB9#M{H__+fGL+iMk9s%u_(g6vg`!`0(F5)jQ&%nRd9L9|Wuks8rotys z1^;gMt>Uag$8=lYj%jl7;a%mod=RQ_He$M@{|c|2a$-!Csejlsd1zJ`rYq|k0c`%c zPpZ4ncA)q;<LuF%{s+5(f69=aSUHV>3C{f_(i zs*NPtQCnsuo&!>B1>`{K{G|2h(WmdhIzZ*x>c0`3&5#oQ!9$odJnuTz*)Lz*xX3B$ zhyHE%-)X_jMvjC2)A*1cRjwLD1tY;CnNZuw`2C%o+?Tni&n`-an-YJJae`m4AKd>} z3Tm*d(cYnl5gz~f`mg=U9x$R$`~ygbohHpa*WN>O?Xo?LSEX2#3&T^!U7mV@6=CQ* zZPT~dBK6o4)})1Ou3Zc>td!w`?R7 z?E^5PDRjBBzXJw1c5j;nYO%XRFGFj#`X5&~-T>>%$&3)2D1u(E1(#kdhva|((mxSP zavs3DWA*@C#8_SM8?$A=EWSI`iTf8|29}uo{&Wtt9r_=-`vC&^%^9I>($Fwox43EZ z!0UI|@xsdcn+27=MxO`5Ez6yZUxA@au52(dP2xZB3lO84Hq&9DO@bKuRYxYOT&yH> z2zz*lW~YE@Vt{H!`BEzz9ET#OWklgh44cX}ei#Udn%{P~qrCyBYTmJ`ht4h#TlpL* z8w1hf=)}8Zi%Q^J4N2KN=BMjr4fJyZ7>fcJ5Z(9EndA#S(>1m>kxm3y>4)$DEBCvU zx1Qh3MX+yj1;}UAZst{|3ps+=`1X3y#*Zr#F;6Z42Vqb^9vv@^<&r)TeSW_3%~-qf z4s$3#C6z7F@{5AToM8hd=|C%2{RP04nLC?pC#rPbF#y?50g!`P0c0BaS$ok57P*V0OeCYQd)UPXH>?;Zp;c zyM;96_ziCUZ7NDiw#*IzZ(?E8Rc%8{&qu!v_I`o?V;w;KeqRz3Aq%{EQu%q?9gQgD zLV&ul+HhN4Q)%0S*|-gC`l64UaDJxyxYl;2RxC3QF#34WEyg94V|CWw}s4HYymeZAN{v)aggh7fa3Bjk5DAxK`Kb(N+(Z zpKA)F@G&jGY}`f0I@7w;0jNn~$!FuJL2NDk#0V>njxjik1ixFpDx}=NYd(MD| zyiOa$EERog|jI0lAdjQ({=Q{3fkzdZTVO)-2 zD`HtZp*L%J{>*sfYQK#{2RhKv9pK|Xl)H)CopbuSvjfV{aIDs9vtx-1lvbQPk-e#p zL7&M)cCUhDAGYax_WQF~-=O}QUhR050aa&GY6C1kmw_#)4-EnQ(>s8^V0^F3fDb9z zpL0RXA+?31B26(SK9{$7|L30TgGukb&; z^wH%%9C@3(>6363Fe82H2fLYMma%Q?$2TR7b7@mSXB1YW;@6 zj3?QXq`hTB!#}W6WMDxA&>cP#UqvKbxXG}>ap`YQEPr7ynQtZFwPn`@Q7pY#1EfF{ ziiPwl-W<6}Ht-|r7hDIdL9<2mw)wFV_T*B(gQok87(x!exb;Q(0O(oDB!5&e76Bb% z@!cD91PasNqI`kzSVyuL811|cM4yUm*zUDrhE>Jogny679=6h21!6Z>0lR6o*|RXt z{#2%6nccDz*+XpYdb9{q>@A)DW(%7bc}xP3LrX*0+>a9zrW9C zB;-#kJ`?TxM%*`0*d+IA>?XOGI18RJaf?x={Qmm9$Cv-pZ%l#X2I>6J&Vt(r6pNm$ z-&Mr=(5qnB90>imNN z*Se&okABijNKnZ_7BKZWfWcI&R^~h)DJpe_p~hZjGVV`81HDz)H$qM)mEj321tX2) z3zR%1ff4g+-0(!+$uON_0caI{%BTBQm`feBe|zh2sQbfHw)pBOg_au}p+-#&Ay!u7 zdM7?$!D9Scl}eNyFA;x8O_R#%DHr%BB^Qtn7l98}qk8t-o}Z&|O8Ld>Yx&#Vn(Ged zL$~=6jKX2Yc_s!(OFeJS-p%v^LK*;qsdD-fQFl;x|8UeC3z$U4X14$Vr0c+~ze+h~ zmj?x8YUs*?S}c2UnxJwD%qz}+vvY3%I1YT0a#K@u5?Z&~olh=7$qGz6N!Yy!2;!zY z7FLFbdyGAIS_cXXOxd%|N!st^?Q;UIf&U{uco#EQ!@$T_f?Wrtx7_iszP-7ka8tgu z>)0e#<1cst2GVD*kDt!D15@d}t2lUKll3Cn2t^6I(kDkfT3flDOMYLzS@*?MwR!uX zmcLu;YSEDJ@Eilhcf(QJyp}G>1 z*k!^>I%G@$Qtm7tSr0Sw?|x%p&IWZUl1_n7NDg}(rnH|Fy3zO5|GygW)W)`XS zSWU}6bFcK^xE&0J=LJ-?buNq_VoyCnSSFT{W`uVWZqwY*7$99tl1Xyp#1a{lX~mf6 zHmBZTh#9+o(1+qmNm$5_liBf#&H6XosrG~(Z@K6BJKt7dW@bSSGD+Y5!#tQK*diT{ zYy+mTEeLCK(+mV1L;MSiUt0S0DLHtVxfW6e?~`B;@Dg5Aw+w)4_D^z--e)ytuK*n2 z#%M>tT6VU;Wya7?bgiJ(W>rR+{8^fZijci6GA&l>i*Awmgx5*(4V#|nAvVz2;-w&= zXKA~2MPeN?KER~kuWou<8=^jACz|o=@_39+W(O`h49cGMu7s44I(#I*m#O1aDL<&_ zI%pchBwBRV{wXHgU>;A`&LQ()G~VJ_f%E+W)n9O12a~GE6kLcO@G;->{9=358eI^j zL5SvPvp`61memY1WCu#?o_|J0CqC9lik!2G5r`lX2Z1B#H)XuB{f5G(PbFwx>6Q_s z(hA^)9sdfr+K*J8sd|VrQnDMJ`|~;Hb7BROtg$C|XgJ_}i~)C9@U-7wl%rzQcl-iH zQ-kvkQkmY)2pi7XSPdqnTLl~XWoPBtbYT-tx>bw^=R?Jyz>F2DGHFW}TTk`NatB}1 zW~VWrO1Lfy=aa2+voS|+g7xnMUAdg&5u;~e^V6!rxsC^-smg{J{<(raZm`25vLmhFl)=aT6^Bef1UWRaj_x_~ zGF=!`7FLDfrFHxDm;K&)tK$9LPs6)*fnKWUnCnHg`;q67av@>E`6fqcUVs6){WXQXBxB-H~74II^sh6NRS9o1;k>I`DgWiIM__HD)VvDC|tXg zKo-KD9!MW7g9a@8PGt3M{OBF5BQ7@oZ4^Cg-%=N12q+ceb)UHeP}un#2}w{r+5zo^ zy3aHAy258lxi@4A$h>y0KN@ z6{+j>Mqk4mgw6q$lqK~KExXEgk~l;~o!pB|m8Rd!6u0pYmUb9Y%Z{R{D*6p6?G~FhHExF*iTPB&8(D9^Ol`o2R zAN^xg5((I`EYaVOEM9I)t_}3r#nt3w0=}BBiW5TX4vxw`Zb+Bq42wCz)2_)kb= zC**{S^(|r)ARQOXi)3>TuA*P2X67w#S-MTSICRv_ZV^H@e5PfKRosIVM9EPD{FM}m zR@dL|0sFdNU83wln4oBit-d~cF)pTr{VB>nD7PnZ0S4JDOYZj#P%*(1u{PnMrbfXktMx~0lydJK+(q4)Fv_O5EPSu?Dz}05C$Q|PM$T8sX)dC1bk4img}-?{mxmy18i`bbN2^p1sX$O8GzhQl zj?H9~E-4GkCUpTcIs9Cp|2fHhF6M7J=3q0>*>CCmJEt@@?-2FA%g7G^d`{K@)@|aU zG`<}7O4d>NOSa}YbS!sNN%gKg6vqW);ub{3Me^I{a`7+aIFtlOuNYGvNoSN#tA-Zf zPP4Tsd&mDaMNSwg=a&17497a}nd)ef9ODMRwnp|}M!TW{kRln+=4lv{ox5DNO^1tf zFX+Q%P97=>rI5YU`GaxCYXGNEOS+zrqvt@O+|Lr_>)A_{nR%GzCij9FiHoOt>!MN1 z=F3;Op@sl4UOvR;^Azg4}y5}b7YicKZAoQh3Zeczgg{{nYp;dfv1WI zG~S)r#UJ!2>B=P8N+^~m{m6W$#L>4uXXDrMFcgtB!agQ9nY)NkUKZX#$r5&4ge{zI z%Z!ziO|dO))?f0!RbKWr^AQ|2Mz*-oZvaJ{VD}YQ68^**CTe8$TcQ&>ypeNjp5rcW z2TpRdI8TuTySDNkOSOnZ+)YC}bFtznBS?p<$HfP%6u|7J=ObLHkzlqw22nnHergwh zmOZx2wt-VAU+?=ttOijOBl&j{s-AKY53+aI);dTYBRUd9Ivf&hb0STHC15O<54SII z7RQDktTWK>oc~HlicHa?G`};A}1qQp?R+cXr&EMiY;X-5v~>DJ8(AdlmJUc0e+d;0`KB;(x>7 zpKak_Z4FIIBXrZRG#3eala3MN(Fh|p9!|*CAzdO3nhBQc+(+2^%N|{D+G0vHpl#LD z_ZluU#*wwoPM<+IYgH)r)ta79j!57W-LDiY9zSBmG4^#Ft3i>=yMfx1_%TE>G5y!P zoGQ)St_iWf;In58X@j(5Rwyfo_0N1&@E}e4Q**=I>;Y@)G*|fu2PGF-behQ@kSH8< z2$l=az&w>#`{>s|>l&{4H?|6#eejL-E=P>Fv4T_l$vtJg7X?{l$8>VEapxo)P}@5u zxyb&=lr9Y8asln23hH%+&s|jg8dRG=wX@ac7x(EWjum#hSS~iD>wrO&$c+DkjSS7S zBd5)`!%0PWv<$&NuO>;kuREoZ16I4cnK!7rI}S~9Q)9DkqLtHNTl9Bzzj=HNsl}IZ zD1G+rp7BVFZC+~Oth9M@N%*Fl97_q?dUb^qa(B}krA*vGY;ZXe`Kdrkp0!iw0zBE5 z6*Ouab=2-={NCZxykO9Senu(rrCC%!PIFJAn_y`7f-+nx0R5H%^8RYn?dVaoeWP9r z3rZCkWxM;*NF=h1bJX_cz(ycJc3{zsyHrUdMEuE^&C|>G)UILtqO_aAusC8419ls} z{lVuc`Wn+|rI<%AuD8z1^9a}t{P{c`f)bVJS~@&<=eS+1Gv}pgATao=GnC9)tHp`% zAJZAXp?}n$-Q-`SvYhlADwB$QgoJcXP``&v_KRvVZt*~Kc|fLU<=ub%#_>raziB-d z(jBjIiqpzY{1Xm`93UEk+%|)!c=(-F3T+vS7FqRUWQ3XZ_P(;@iM!BR7c;fZ6Wz!D z@F@GhsF=Mpmf4m9R#h>SXRmc?ZB*wYkmuCAAgVkj2@k(_ehmsIx*OQvhg;0X zKTZcB5Wpq;M$ltX9d$hXFw}00N9Y$!F6#ve;am`9|3jsOwFw0Kq3@Dd4SX|$?>||v>iSOx(uOH=CFf?1-TA+ ztMO7kM}F7CBWbeweQu$45jQi{s_w~l{Av&qk8H2uC^4|&h95enj9qqQcZ9MNWnNo>|%$5%FLz~#CY=f2`!-K)Ea z%S)x0wS}*389pkrb(l=A)RmL8TAeFD78XjvP16iW*s49nh1&+kw;dAX8SB<5tCHzTCG=s(d4f1AQdn~#qdw|2UO zX1``g@->{8euC`mxcqr*QoP=(DOMiqcf!V+wYJQ18H@LnR*HNf;qkV+K!RfL6+C0B z;fp5svEsYPgvCDNKB=ELvq_IdoZ+TDyURssZqd>}M^CN4BZ7zF6WHN=4d1==5V_X~ z9xZeq0X1oNG$iWAsqOx7VLf{osnDxNneh#{k_BDWv(KaAj%1bCPJ*R>X=H6_xO(wz zk4%S^)ZI+-oYd|{N5|a{-%7lgjiYIMN*Kd-uKV!Pi9a==~o^H*}4M!K=AuJ187{&!2THI2*3vq**We~MmnHVX`C zJ3kf-H|6FZ&J`t=bGg}%UHW3$lR9{`9(o)x+>8`^Zpy(&o8&IAGyucM!-(|hGXD&6 z#0b-K8u<6UiuulCFe^G*WI?DEnwes{O%Zc1&905A^@$3GZpO;eCAXGX6<6%)1#e(~ ziCzV1(@xvtBSB+;NqXJw6F+4|+_a-p#Z%4~BY|p%6N8b~D7~13QVD-flqmbsW#T7H zgTGjIXN|Da6>R2HsiWAiOLUlR^oOIPTQ%_NwQf6%6f>M^OmjKxL2pb(*CiJ5)OD{H z^||ex{0VE`dwJ_DBT}*|yY=rQokc-Rhw@HflnR2{=%M*s)2KoBS6W}!9vrSsQQWp>%&EzM=$itK2dL`%gfO; zHPX5kb}VL$tu9@>ui*7@76v6L>`0PA)=9j_h7>hko$){ar zKWn`}hcp3x()?Q_*++1K*JTd&y;XL9r$y8#L=QOEZVheLi}-T*Cg{T`Yzv2K_X zH(?%-nRg>JdgdOw&?1JPMBAa@Vvec$VolDd9r&hOp9Hj52Z zzO#e}`}B1l(rdyIQo=qQvmEqH5M2^XMvV^Z!zrQl?|Q!RHTm&3wj=wXCWiS8RnJpb zONJ&9XEe7#q9pqJa$&r5l5rZHY@xs_I{rG^D}L9hCkR~Kr{7A^PG^5yoqF!JzH+_K zOyg>!;L@BBzn-q(D$p6Cux3S<@OoIzuA%DQ33qSu?yxH2yCOT66 zDnQ*=o$KZhd}sx?!~2SP#v39%6=OW8ir{Fi2F)>zfuC|3Bs>d-BtcoY2#U}BhC>US zPAzTih~k7&OndUVk{3Abj4_`|X=}FbYrC3XsY~9QeLnxj$C8$cZi2%=zDb)Q_5PtV@v2yuX`Nkga*_Dj`E_IiqT3jFvP3ei(o5 zjEJJh(!or;06pr@`xt+G#Qk1|RgUpd#_ws2XS#&w@=dX#$I1v5&pu7AxMAlFoAG#K zR|0erOp~1NFPJWPoE2RqzNo=h)S1bSha?-N<9n*2_slK&XVVk(3As9E%&+R9sNiEQ z1agTxP2L5n$R$vjv$9Wvxz=XRGT z%Rs{2qQug@v?2=MNOv=XPL=_NsH~a!7?B64J0dc`cyeBXv#NnS<*SDLE?8@vxe3@M zu+G%?wMgPHpyLtxNOXWW< z$&r;TjdCu`JHXjLm(RUYBC%(2wm&Pj%V^IHB98U8nV;mPBJV9qe9RKetAQU8{A85P z|KovhJ>_6kkH|ef2I@`=Wf8X12R1Tv^~|a${5+*b(Wgmk3x}LK1tofc83l7d@WrGzr1fk`Ig| zJ9qr_tZ$kPxbN@lOv+9Xgc48I_Q=G+^IDAwdPYA&aD(6NVW>Ii!df@uEJMT^2&_($ zjx>#PTqY@?JI~mrYT^m)BCF!4$8tCXXmvVr3L=|0RO}WUO{XK7glBVR4mXvT=62Hxwo9Hur@5ye03nUPfp@vedVX)44VZaC%{-VvyeT_!FTZnlX=ZQRwIcef+EuVS8q&iE-Q^ zR;UG8gC<>s$I2`y@C2Wt;YN~TpGI4~_RYyj#53@^rZ;g0juwdop^CTp{ zF>j*yw9QbS$fUwPlvFkoiR`@msPKMdL;{DL67xW7?0n2-8oH}?^j0MN7k}T6N(r^5 zlO(Vb?pEBAi`+f-O@w*`$ZdwxRzFxKO$zQ@ooDXs9)^1DNbQ-hTt}{f`2mjhe2EdwJ0EiK6&$92^@O7XW3cs z{*I+);>&uY_2m1Uq8gbH6x9mN+I-h5EtJBUz0UE=-VSV zKeQ`sfRFs}L3#!qGZ4Dfx+#O~q)jep-oF!0%V(+Fual_XCDca1OLTy3FNOXLuSpCm zPC9qVAoq)5-WMccJdZSg4*H*(@5bhr@Y{Dm2Kny1&{RMNwSR)_8(qP|GHtvrc z4dls7E%TwliIe$i&v;WP>8PG+FmL2w{CE$~{Pp-~^Yv<9_|nJobyu*O&NGRRtbI80 zyEbc}yX4$B?kj26=M7Ikp166|qmJ?8?}3tCf{`__nAF*Uu+hZcGj;n&%}{CCFW16^ z8n01}(%tNZ=0h)g2OZ){K*!VGmj;Xp?p!=n-dsT;4G-!6O!r$ICj_&qCNA(0^(l>P zr0xZ&rBi+@(~8qvQy(@_MMtdZhx#V>;Y)bhnDf}Ic8CIngm4G+C@|=={ zKk|>15*>v&J{(sQ8d2TL%4yfyhC%ksw8rX?@}!|A(-AHzWeIR=Ts9Ee4jtX(U|#2G z63=I-9zVxQU;O@&f{)qdsM~rfMgC{92FV2(rpJcr6ed==O7hQ5v|)!0tCkG%IOx<)fF+`&#~CjmQF5p38Ge-c*3UsS zVru))5RU&qc**U7hHkdxT-QuDCs5nap8zCMZv5X$?-N->L ze{F*Xf`>)HFU8M_T|{r(EiQ+e?KqK;T-Kz^_bj!$s)4%@jIJW5OImMr3S50RXw8~O zHM3h#G;$V({$MVo)WR__#Om6$?B1bpl-T9a&Zf&}cc@y=fai+L+y(k08s9J-yR)@f zJRQUtQJb#5S1WlxKw80?VDM^bsxG!p0|AP7=;ll;vg7!Rnbb}nDff{sLKpHrQtghA zWN6IjaKSsNC__V2GI7%7Fs)0{jWwC#OxJ!k)Z66_sNu4BwUXPt=r6)R=}~7k!;$t5 z-b>%h@JJE1|i?3J{qPsA%DntUwt~J|72DDTrK|e32nu9&Fjbs@e;M{eDSAk@ z#Flh5;hrG=1J<<-TKjbHj1grrvV`z(<0BM5^q7|%t0Shac5(vFdhTX* zH*!j54Y1@I46)9@sdT0}JjwhW?y^e`lf59!6ra>+Hk}bi0p_B68D7vGzU~F*UQM(Z zhz2QW)b(rn&_9oQcLa2~^fxA?;hW>-Vme3ZNsi~GVixou)pl;@;#6y0-YUxjF*4MX z@EHG7R?u%Ydw)V*(%pr_TIO$ToPBrB-EZM8KLYY9{ET6OpX{0FYlxZU!v{}%5}Cx1 z2a%PH5xiC>CE^PPXlA^$;i&rhA=vRUlgPtF+XLrpV1~_>6DGWT2t6+wp76XrY`8gG z{*_{9Pzv;SuOWGan%=Q~h+fFs1gda4CQviE8*URypYh}1kv-F$1#DC=#-dv0VGS@WF$;SgpT`f zcI|0!^3?s$Tjs<@yg_?+w;|Uoh@fIWJ|rjo&^IoqW7VsLsW_u4TvCSzT6WNLi#p0v zswv}nyZ7Hnq@Sh4s9gm(R{U*NLrvA%2WK!uo8B+t5CSc{BJh=n(Y8ytUF=9IfAXDS zDB7v~vD)vreJJqEnorVs>bxNlC%FZ8weJeh`b3fZqL3=LDx}q>3acu|hgkg5+AwjX zSctgOoWgYYSb}a^q1wx%j54v-imyyxb4Eg~&vB?S#fcvece{ES>OWv$7YrMDod>*W z!u1f0gpdssaUPn}b{6c&cwP#Yz z6DL@g^URWHJ7nVQWrf(58~q)QY7xku>}IY$pUa&)il$m>%CD53vj#y_iuy0Zw%p9N zIs9K7Ww%G~7e#Owgm)2B@`WqC5=?GbrsD&uh$XRg3UIv7-H!cPpH&*`gGTvy*`Ju0 z)N1u)&tdd=Y2Szzm?yq{w95C!n)QnnmH$feGn}tZBJBExsw)q}8R;G?y*4!Zz^^k@ z%l*TUO_bxCe`1mcPPjd-*vG8SmD16l_EDqXQ~M68oXX8H;!{?VuK~>wtsv*LL6|+m z8xvm=;6vc!rGMjJIaisEgoxJ>0Ka^L2YjKxBYDqxX-UGnNODoH|Mw&IH)VICO$BI{ zhF;{2s#tm`do*_)uD`UL2!?g=ZF7zo_b>y0VXvN8%B^(D;;h^7?fKM~s_Lo~0tx;( ze(90|C-M&TOH9W;&&lNWEo_O-%xO_rOLAGOEQKN4eh`AL<86$2&F z=*hVx9aB0JKQE%ss}uwT9W8>%gOa-c0Z(4*t()?P+W~b>;Bxq_jXGi=<`&xl?A;4^ zjmBckCbOC>Eup6~{AKaJXG#gLB3q{G5aP*Phhz0eSma~L1=8vTe^~i$^8o8HMiFR zTz6p|%b(t!Rv%%Td!qDVx6c@>q%PKzh>TqSu0%@OI!h_>N>wTV_Q5mra&d3g0YGZk z)rW2BUb$A;rD1YdD#w}=pAG};UHF>W`+%*yNZo8e^n_5Z91O)KqIQ#dZ?YRK&8%${ z_h&l4Uw<@8fo@EB`2c%Kr~#J*Zqu=T_h&8xJwW)ygXO0F9M5#^gp1K?XW^I_zD=m! zjnV-ZKS?RQL#*iw zeX29LQaML?@3jCR$g<4E5xi0V>QIPV#PiT1+Y^_&s9bJroEkV}SRZf`AONxv4%ZoG z8Z{D2U&nzg2Y|CEr{|`*tw{EvBYeC7w2+PsYTh3Zr6ivhiGY`zW~51(*%T(OD04Z zeYk~PiZpJ0jfrUbijWQ_nO4}eRluaJWUv1LWmY)T`69uKQ!Jv($YFX_^U9SZ zz%)3jF;+(>(9~ZeBT?QFT2=RJ;7bnzObDGe0Oo}RotRUb+>uMS!q%ObmwNIZE1895 zS#;|wVn#C@{>t_5^%UdglqBGN**~Nq%Na3{9%i|rX1Fp;LJpkiO0zmsfN(=lJKldwP-YhJzhXUU&-F8$hp6d&?zMeB_SU#(7 zc!q5%>BCZgAX}i{cKZX%^VacKPJpeDVrJ~2^zO|v$u(9?*DY$-i^O>bmTLlhkJ2!tXm%CSWNJie4Q?5iCJTrIb}Ki#6u$n(wsU;Y zyNVoy@O>cpo*D|@#$twsl1m8(MlkocwmJT1Mjz<@##&Jxfb=4h8?fa4;_JYfH?B5-IzAH%%QK(R8~FkV+Uz>Zd1J5eFqy;V zuYXPg4~GWZ3YVMX!7l#StJw$)Akwou25|Yn*eiNfD)uWA(`8ryrbj5i#dRh80lC9F zM+Lhl9OOu(RHA62M|NniX3-V~ORuZmCcUls-HdIRh}gA{-)*L7(nL!M+Kgf{h$LHA z?h!RfptsK*qs_y9?~~yBB6Oh5OZI|9a+RR0 z&A<69rzOfAX{_13K1Z3-l2m<>0O(E%tretSK;5^^Y1)&y?D5`fSPpPN(Nis8nX_)c z0dEkg$WLRy#;)3LfjG+Cl(QAI&D-?td29wo3A$xvN&&qe(`-i5F$ztyLrxNVM%hy- z+00`Ph034#C*Ke^jW46cjFidKSYyhwl+=tYzZ1#*#_2wct!S4dY|oDwS_Zt%0aCrK ziR9U*hek?ZX1)$&q5$MNA(kEDUg}xxGB}HsQ+B)j{5bRh#U4q95xdU46=lhu^?S&A znidbXf)H3o)qO4bSj7(vk9d!@5yZ67!MQQYd77bZloDIihD6Nnqe{BGCs@T~cgJwy zKl%`Se1gQ~mvKisjrD8bWF{W6#dsWL5VZvDC4imD{<{K1odImJ0Z+a?PbFj5$X;Gw z8tz!VUm>+#AsPUBKqic2v`l-X4&*^k{b5v5T#r_2s8HQ_i2vz#+eFrZ7>FDblmIln zRJ2FD=*>MI`X0*f(Ou<{mk z>lekkZ&+{67WG#o!T->~urT5M>H)zdgSMU5p_Z!BErf++=vo(HuQDRS!YayrDWJ(G z^1_iu=Ifdlx(s~G7$`c5CDc2A95-Nq?w4IMwQdjisq|*1C{}gx5bGl*!rKMZ&>-46 z2kQyBaHkxiZd#nYy&;!ftXcI^IOXkxJZOxJS}nfwcWQt?(3n{Bd+gK?*d@^?PxGk` z0cU?h&WUtw(~H)vQ~XV0Gjze)4*Jgh*}UHFDlwmLL;^`cIGgu!pZ|sXna`szeRAs2Z_@TtP9<1yI{}koO;;~8OKe+(< zh2&3YE>_QpL#ucoeu3eSO)Z6=7IzvJmNy~Jwl^sWsyX5w6T&WknG%DXF1rrtzsv^Rs3W%x9&@(z@>ALAfG}F3lBBKH}9-B+UD z8vCcepm*E1pBFHv(N6a4+qPY z-o(~KNTCeT{H8z-rIKLcQ;`#Y9<; zt|k=&TsqsVYte#^zuhpDB`%t?Qus!v-^x4yVz+1wMnK z+sCR-PCZL~Kd8H_St=)5E%_h6Ff}EqTd8>K5WQO@c-x2UF(H&RI)mdKo3U(e8#?uB zesoCK7r?e<6TnfOorcN>%*7{Ep1Xv3c5Nt*7e+c{Jc*ng(1~7NpQWuJ$vJ11+rPaP zi3$UTDz8pNZKJN(-bSf@6$_T$WPpd?`RMl4ox{TKgZCCIcd_z3C?0Nsd`gOpkZ5$y z@JaI6gW)YR;`P!af)&K|5-2quG$JtG2)vQunhNMw;%iYtxf}5JAp(&>@JKLR%1O5# zI5?8!L@ayJG!p7>pIbjckVd7f!ByP&E901*C;G86?-fH6xh@$=3KW6HF-XMg(nD!_ zS2n^S!1v~2lVbBM?qT(?--@@UF>N|7ClU4H zE8ZO4#8emPROz&KF5e+4Yee|P_;4XhcF!T)Jkw_SL2oT>Tj}#a0{s#eFLbYg92X?&9rt+-3Llist<1oJC!8y2` zXs?=pzpZY%sD1)zc-PJkV|IfVZ}8v5@Pf315?0zCqIe(l!ghj%sm*Myxne_JWCDlg z?%)6rAH3#`rTVWIgZOc7yy8pWrg_)yyP30`(!r@C*n@vrg;K8>1wd)7M+lv}Fo~ZI za{;wea8RCH)x_^6DSS0*wMGyr`{gwZE!>0cf$)9D5U~|@br$imUFT#}nu9}L|2PMS z4E_Q?GsS0cRQhr&KN916Xv0Ad=W`hL!BBnskZJi{peYBV)?X$Rp&}be0|n_!#eg05 zGg5BipPvxJ{_wNoN8Ra~lzl+msXHYWA)?YPyU=Y2zhC?N_lMZ&_tq$1r`_**uxLvM zc+kq~;Zh#^SL9#2EOCg%haz*6f{vse+AZf7sZ~{uRGdvb0_+9?^kJZnAKB+Kpe^B~iK7I6u;fM`s>zT6o-i+RjEa*Cp*98)s1d7j#FLS-1qA$DEk zI(^}95}j4^6FlRTC~ttQ9+DcRSZ{G83uDWDVp_~&z?8>CSdNF38Gq--MIOh3;$sd` z8>>ruBCrj8d0RW?mCb{ZgIMa0A1e5tg?ZYR1Eo~E^E6JRHv>w1bF@^zx=Cv33Abn< z4Hbo2J=5maPNKuM>>qw4?(BZw>nJ-N1WMU;X<=)hLTxJB*}6G}85kK)Ey@#0*RMtY z;;Lk%jJsGzQKN0nsPev!(s$u-@0e<*f6d)0Sk@B#VCqi7yMLyjbO$U*EUEI!hw_sW zTV1~dRsA%CeD?RC#a<_f^Mo=!;c)G((J#xb4 zHFR2?MV67<*XIYfz=49Za4NlPvL{;j>uToN^>Zy8ri@a=nk?7Dv2=+praxVxKTj+* z(`{AUFNvSG_GzV!(FMleXR)Y^O)O1@vTzz?ipl4>K{tCdvZL;m+&;gjLSons7^ zhnFum*tvXkN=<EliXu%<1r>u1{UOKG$`%Ig6!l!kn-TnIsoy#RADT?=XB{OMPN z_S=NSr4#u}sK!&~!DDM>Fv7kPqJ&W`x97TzYF;L;+{bH%hg>90A`hXLWe%Sp$cES; z?JTb>uxt38ddLccdv4VMXAh2*W(9CkSL4W{MMK-ULJZa$gr#v*b8Xoaa15tgXR#?v z<8zFkYsIwZswXJyWEBdaKn^p8J26VKi9c_oYldqZ!l48cXvnl&S4PR6e&6R3`Re(3 zic9s5?c%CaBj6;BX!=o;{e&1u?rw_c7f=+3&cq{Lb}4*f+)hK90G<+Jw{5PQcd5eo zb8Xs)zR*7JBm}J<6T#*Z1Ss4)vorC*7itJxx$ot38~pft|1J$WK2f{s*?cAApGW$R zt^fRYm^%>^mtxwZ&cn93`m{jD%s7K-3(uDym5J%}&sB!inYJ=Vo>E~`=_ zGuG^+-+e&H7v862$RC{&6?u&d5#FT#GE5&s8J4N_f5n|;Sd>xM?*#;rM(GZb1_z~6 z8bL~wZd4GEGUy?sYbcSHmPXQ%k`$0eT1t?R7(f_gh&lVl=lOK5>zr@z>sKz}zW2WO z-fOS*U%x+IAcDV@UK58D`P;IrS8_=L)rfTR`|G!vLx#9XKoP8SXB=2*z zi$@w~beJdN(66<=+-DLkz%B+#;vDVlhR%Fux!I|AXMSaxzsq&<&?~Y9#Ms^6oYVIg@QJu{yD& z%h?Fy0d0Hmr)X*`BQi*JX8O8kRlg@W>06eWaR-p*FgwkqUw@G&xiV(SeInRf(=({- zsaTziUp5G(fJh!#f+bXBZ`{Q5_^^MGY!&6(f_g=3H*7UTe7w=E)_OQzn zYM{EbrD(cFgY=gE%BC z?V)erC}(X^Jb#c0MqB6FpDkOWf-nO$i)J}yfpL-($170)3)@S z#z@Np%Tk@n$dSJ_R7hkt_2h38(JzqH8xWF4$5Scj-C15e_3EW=9~$PojzPZjQvgCI zAmjP+=ePXj2t0i2ceLz=<*@YdfaT`sH(SJM$-f!%D&xooH7n>^SD&O8ijgDs^GVU} z*F~Oev?BOL8bg+&-Sw6yD26f|Go>lq3;t@8pFOsy9ofjoM~~~&6gftl)x|4{v^AS> zzmZJ&sqS1dM@MnP&+$1ohQxTX$I`Nl>zzu4_s^=9Uh3)44-#R|nehDV%j3Js%kICr z*OOdl@7y>c(EP%<=V5^Q*wU!q+4Hp6RMm=VOP{T?m#Zx(#R)%Q;()nf)skNJS=J0` z)d~qK?{EcEac9*xu=GM(kVnL{ftX|3W1Nvp4{_N!9hc|ezSjYZn+4a}1D5jm?y1*5 zio_M1ZoSTcy(ybw>Kk6nDX=#wcqePv9RFT+k4NL>q8{g$mBHDgUs9v(Hf8sC*B?;v zJh^8B<09d6iAlEM&X{u)c*z6agN}lPhq9x^pw>WRJH6gPRzX+!icg+uA4$ldwD^-P zMjhKPpSB+`;@%sO)IaJT%-N-jn~^Ukrz238=5S&)Anqi;%Bop;&%}T zPB~6@K#V@^-p5xayL~3&?6Mlm@$QV3#cNdMqO>-h#~lw2PmRfy z7o^kEhK2^P;bRVxZVl;yE(0IOtB*M>MOeywZVWWhu&@6TH@2(0DqTUKLpc;vg*VtWP!pNUYW5RZ=by@^$ zcP3tBucp78xMsJj#e3^}gAyE(&Yq8fjgvS#-fGxSBcgP3FBPn;5X{Z9voY0*teE&} z_mata;!^5ka?eOcP%k-*qvkSvQ;SQgfwh(R(2__+j+>yc$K2ItZ|zBxdi$EUJEiq> zk)rO~Cad;Uzt6O7oko8e4Kxw~PwE@2TdG6jo+3IRA{y3i;=(JEA`2topUm+^MAMa0 z;oz$;-Db&8V5h#!^LUc~)Pc@th+2iSq&-jKTX?!kg}!V=N_mLm?bE|0BeyqSyk;uKm;Y}7M^;_)B8n>0m+`dOC*ZkDK}!3{3l!W^`B8PE5AxoU+)^tG9}) z$P7n8PW*1!cy`uMUB{%(;-LHpOe$VHcg4`ivhDUd-buT6&_s~TJ$d`eS9CS?@2}>5 zr4y%}Gc+hg4Y824^$Sj_%q!#e=zZf}oe#-H=GTZ#Ba{YBl4|0NX$u(OieW6U>xUKuvG#){FX}J{ahH@*JJ{y`t zvyY@nB;LCF@p|ijkumLI%bOWA#ztCjsS-!FNBysXzgd5c`Nprz$-~|#a8GEws7Ibr~reyzx%D)<(Pa#%$cY@ zPn+mL&GL#Ls)wXj8k5|jNalBt>6k8jP=TTp@J?C2k>79BIc;N6G`Ct;pNO}bZ{Q$u zqmq5NAwx8RsX#q$kIaoB_&9DD$d_bIb#^ip1y^5k^kXv295}V?5+oH@jeP9LY}AAf z#>skFj$R0Q^n{u9ANGzSv^GU5uk5lf*G6U7>ePA1ate(FlFW~yVgvP1TuDi7WsVP8l5AOdd$VxgJV~>5fL~K;UXUJX|1&71HhUn7XPcH2Hmon; zC&jc|FT#VlBWBefqtZwhL-+{Wj{p?lUg~#^RLys|LSwwX#}>PZ(GXtO?;3k9SI_Gk zPSl~;IpKSNMJPs2$e>SGt6M};faSecF7xV#S9G%njbfkb{ft|l5+1h*G3GnW_o6jj-; zeevwq9Ih)q%F@B3U;d1_+qx<_;4pMwj;pSk+Mz ze(QifYiFTI9Y?3xU+if=&Pds-rrO)_l>42hq75>d8A9xaB<=)a9aK3iV~l~w&h>Fr zR&Uza(~1IkS!0<}b>_%`r;43Sm{-1K*;TBGd`tS3neyzs+a|;EC8LLak1aO2RhwJx zd{vnmy|P*?hGBD2SKz1Z-+WsY@lG-;r6YX^GgJ52TR8ir#AMKi~q zI$qA;JrLx+6&5S``0Q7+1(WGWt1~<`aZ~9??lOIBx zW|$xz-%hBJGwijry0=(yU-u>ld-RbP>1e=SPcYxh!k9F@Ph%GE)6O)LFoofsNhi&E zjJK7M%4Xiz$TZFpXpD^-36W`sWO<6#FgDTL`x&@?&$x>O{DM43yf%qz&`^z?6i=n% z5stp`^RoS6o5h@`r+Z~x#NRY|#>2xrB-IO~e0&MLc;g5&%v?5YTFW>f&edN+^9iQ-z3-IvY&(wxvA@bIZDHg? zh-4_GG2-bb=G_^ylQ`=|-i&$uXfL}i29kM%w1);#`G4A??GJNoX_e`^q(;VN>FEOVEy=nUpuVrjk2HqLhPQz@if z-^wz<>(1L5!ICb$=~i`5b0s1dNOUlDOK;^Fb)atDkllQB$tZ*V3!M0BxXlDhN&CJP zHi1XQa?$1iPIYWgp~Rb>dAjK-I%~Yfp7Mb|XweyN8@M{(jY04Gic8gaZOKBJDcG4X zTkN9#SO778m6{Gg=y&;$tgXpfofwY+?|Rl}eo{mEX0xS*ig^br+M~uNL{N$gl&uwW zZe161n@xX0p4HAYil?F>P93f_UAuF97<|v3gU2>RpNOBBjFeKbxN$Y>-0;o^OiPs+ zPfL7C0sp(4Yq>W3V??TDnmkVIN(wh4IM7KI(ZQSL9GCz z;~6_pGsyM@U^SEFMu26(rZxdX$84GPyW|YvkD%Yp_ut?Q;a!&DNS4lJAwtdO5ClHY`!im?x^lNF!6hczw+5o0QEVlkCiDI`?#^#T>W z5jgaoj9eG9+fER9;nfBl{2Ds2^dGVN!r0MFSn4S#%4D8?g`cXyJ)0MEqlD7TfJ0P9 z1~9KVGbAQNdINdN4E&){fhS-z{?h~;gF@~CTy#-I*n5Oxnjt+KJZsfQpnU=(Ab^8C zX~jw(pC7f#OQM?o+^qE5cH{<*mP3$_=HANR<+yBLFn3!zmTsrBt!J|#4Y z4%AEE`2s9XhRmk%P}+}00`Yf{^JE(2RLHE4qKfwHQDPC922OA{8$NXj-V_EO0XOz+ z0Xt{;oJYcSHT9`>EGW5-^x)Z`1Ic}@fMs-(oqPLLZmoGh4I)#ydA@)>fhS^5hp3=GBTPDRZ+xvHM%JEDoP~%fSad{mJYBhJqK%9cWPaCr35IO-d<8YS7OBsTU!o4GIbX9+y?fIn`GU~ z5Lxd0fcmTaX@|U4)#3X^k;LXX;6Lq1>gjU4y`2aEcJQ1SVxI?-XcTffN*auit=Ov#T1Tc>{h?$j` zKDYIx<0Vu8Az$D$m)@=&bKMepjr_I(D7)>h0F!Kc0`x;BpJ)=QH!B*opM#~Q?GG?| zf=npD@ClWuUT$cqiCY7ckKKi5*AIMagEUit8J{O0#t{;XcB8BFLt=5UcUK~HJzIc? z$*aS64-3>*`-#v3KHG@;xto^!W~q6E@G z1uyZGTn{G*cP|$_-2iLIY8YrshRi+sOCoevf?JPkr1ZjDasX@x)xX}Fl2P(M zMyiNp6dAmJKNVQN-Ey*XJC8G0%|FVn&QpO#5Sa_^QfN946btA^qDENI^#w!W9(X2ytDqKN2~F z1DfELNPc(B-Fy{>dCRkL%ll_Mwu0ki#dN_3@&wyY0YfXqmc+ITPq~h+HZWm2S7Ve= z$w?Vr0Vb0CP3I9rCcAVuqkz5E)m&7-S00lt6^G0W&o-V!$4*HE2zA2L#mL9GSWYE14N z1=BAu1M*Y#pFwHz&S6QQ?~qlU6(`01jl)VP6R5~g{(&`ZJr+*?fthGv(ZBOGV3CVK zl7C-Z7%GS@FoK?i0|akLf+d}*O*$)__>!Q1S_>`sAOn3yE7%l5@9 z_k5ID=&_@6rk?Pz?9<+C0fc_OA!E%~CzcFgmeJvYm{-Oz{RAInDwv8{f(%r5&qOQS zVz?+9VJrSnk$GLyzs?QFby9z9XuoF2H5pZ^)4WlG$_}VeI6)l91%)8kW3yW4O&{@4 ziM{&Ix1)$*RnpHd2>TG-Gj62J?@FdvHTCpS zI^FMWb=%8CX0usBs$M;mxi?MNgHpmn|0we^>2N6xhw__m-0S!jTvQM$c^#MMNn+F2LB#Awu(?(r)$KqwXOFhM@I**^X{Q&7tX{@7O+c^6F6I* zV<0yjL_$%9Cx!cCOA`Brb%e)u+{hCZ$sM2zCCkT|R>;ryjQA2!!lt`ZRE2k~L;Z=C z^u)C+1ZUSxwh{jBbBOqw^X!KJZf^Vgg`;s?kER8eaT^+>;+?Iv>Npal6Fb*P-8OYP zz8uA99oE-*jA6Fi{SrNQUJIdB6CRzTm5;Q{JVb~lZq<1PF4yEs3^d3yYwWc$l_lm0vp@u>$~vqbiDHS~94rFTBDi^jzZ;tMQj zUisnDdPXq1I!!4!Lf*ydt)HO(dxqP4?M&}3hYjC#?T*?bvx0Xa&1z$}qRa})N6AjY zbAPh5I1;-a5K1bi&Aw36?VCeU`>$6v%LQC!(7BA{PS}lpDr0N!!s?)y(Qm$RrAL&* zg=H$c=&wdx*SyEe{a$YJKQ87?S1Q0>GW9;pK-jqrsnl}1y>Pyb+BSWdQB)q=$Z&Cx zJ5eB3>lx6FyDL$-6#KRc{>pXTqaPpPMm`9a`Jf!$!y^$u#U_o$X$HRGtg+JE-4$Ze ztI}T+2meCcQl9Jz66CWNf?ICxRUp6(Yx*`mOL-EKJM}BsY$943SrUbAYwaHC`)p;6s%&^X^U*RfssLI_aNd` z>DR^adAm94tSZxBjGPsa@KD!!lnxSp-|_wZAe(_`+0(^>O|t_g5}L?e$rc|p_nSh* zr0=MfHnr06Vm~Wo$fcY{jkR2UU8i_MLsatKU5xqC5j*BFeDd9%a7id3h-tK24LxmW z=up6hsr~(n=ES~lh#Kxe=|)7}FA9r!U5_CP*!a%Y!@SM{{m4!|^bt5ORF62y$-_Tz zROXg`3zVxt5v5)$mp=cy9QM(({d(4Lw$!DO7xD#X@PjI&XmM5XrcaKKLcOK~TCU9)KXHQkBC>f%s#<*ng91gyQJk6m;Zn-z z2hML5#UJhF?n?KRHVA8BtjRkSV5Dz&PP%FEgl}?{y$_A_?@J5Ii1<(|{Lq5PQ_)My z{DDaU{0wq&_tZU4`8DH+N4Q!0VX;YpL3W`OIIFV@8XOcatX*Qd8fnR9J0CKx$khvE ztS3h$ax`Eoipvjf(Bnsae-KftGpSmVp?O}bYv1Em?dX?xG3RwwwHyZbf@V=R!7VZ6 zAu)I*L-#dBmvfTb#UCl;R7tVD9Hxs6L@l_`kO%J~1+oy}m!x_~PwyvdxDm`zCZJ+kfzMCLG_QOL; z0s3X*ZSD@aviYd-_M<*^pNMTz;J%*9T5S2Rv0y@uib!Ai>7@DV?N$bbhS@wv;qtRg z;9}eICzBf-3OfvIy+N*|l{m%=^yJW<)?V`FOajt~hEX6XU|;s|QDMXZU>i(u)yd->yF&8ix!Cmt>9mG<8mJqGli%k{cd?c)wxg zm(Ye+vj28+6Cj6+{<5159uBtA_&aU}=x=W_{+szPkn}Q1S0!t!W`8h}1{iJ4EY9sW zyO#p*=Y3a*@QzuL0s*4pol@87HarRba>ca5l(h5aR|79wopflH&B!pwaYrRPPL0hF;ZKsMcCWq#S&#b?%) zw5#&JO9)4p=}@Qydisu5kg@1xSm4yT&@`ilfa*NtnZiZ9`P#Z1bthn)Y=Z{sW4Q9` z*L}1;e|Lbv+)>L>rtfYr>sBiR+p`a!dceJ=QJW~7>MfaSH=rNGB!RVeSW-GmQI_{_ z-IxdYo}vmwRig3}6W~GjI^>>J-h;5iwXRwQvsop<+Zk3_CO#%>~-)IMZW|q z9_lAV`h=1z4VFyxP};f4Za-Hpn)SJV@4464fOXP#AqwM~}QLNi*k1N7KBSzz8P++U~ah>6y`5HvGJVvTXC3`q23aNja5_i+0A3 z0T^1kNy!d)NU$VXSsyB4xZMv^hkOgM&jKdN%r@>BrlsMUKx^(Ak-;1Y?b9GN`kCQ_ zB~d=K^Po5Lx6hy%1PZ;n6d028lV6{Mw9+f}RC`WRK0YspTv}nBV3A4vb0^#?3bCtXhSA@#d+5Jw< zyFeu|4D%H3cYf}X-V6Go#xC+poA@HSa^U}>!;0El$^gh@^ZLKAk2+|HpX1m412tB^ zgWWy&V^Fs$&3e(iY3_3CJ@u`suvG{j+Ag})2tm|gRPR3lwz&K>giOsq4*||D@mfHu z`7$t5Tw?fY`o@giS`az<7LQEAD*PET%)SSbzXJ$jikKx0k3L7;K%Fd}%`vb$@cktA zv+Cm#n9v183Aa@EB`9;o!e_eyfGbTxsuc?KFMOyCC_tbS(*9WO->`)cE_s2h$VXSL zgAL$+x2>2gmwdfhVOsM#{VI`meq?T609vir+R$OI)bi}NLT6H4viC^4s?8C=N+wQ7 z%mZ~IaGbAG#T1kTLs!BuS}AWe`k*yg&Wg!G)<}B?tUmhaQ~ghd9x8p_N|s~$i1$<~ z8?g|`{_TtH@o*@U)$y%DHv9X#lR1<>UQWHvbIGjfl-AQ<9d#fvqjP8}sBrADR*Q<| zhn{=?Vi%`+#GcSzwD<)w4h~sH554)JzV@%vtJs-6S*x?7NKnoVO7;M~dE@E)d~Gn9 zl>Mj4sYWH-m91xnjWZOxu; zM4c;>DEA4--{qRZtG7$L*rdXnc(>)zT5K3@W)G8%3wSKYQ89)vZ4%naks>(jljN0w0j*qZ*|6}^;#1C}dF zVi`}ovNEhh&OEw`zzqYMk)vihB!{$0jdv$ZkTs_XK-!?QnMSo{| zzDPljviLFK-MEyd6|O|6%LotWAYr|Lrb9s}3C`-+a3Z_o7MFf;kp?YTKjT4m=Lcn( z4jePq;lcTQZ2tJh{uC{s&>T+( zS$MwAhKD&y7FO%WC2gTOCRsRGEHR@gH-~^X*G)()twwr44*%QHfdY9makM4rh$w~Z zG-Uj$RiRw!olBUjO?cOoz0Y4jIKE$&7EeFCmVL6x{Ala46+U*RFn<3EXm(4AhcJL%z$#C!z5_qV0~$;Xc6Pr;@HE9ldY0 z8!K?qsH{MdPLr$Zzd+k(&%R&7(4QR7`=L{3TnVPFU41{YXq6bW1n%{r4i7f zK%vJN6@3|GVhz-agL7@2@3G>B+;Rrqi%L>I1i}@IL#NdCDB_BMb15EWrntGQGpT5Xpdd0 zcU>ChUge#n1v#LQ3EB-%b}zqFIBCR&0jc;NJ@11W`F>|xe@Qo_-FSp&K^8QmvCQL7 z;HHjVsqh1Rn&o?Qn4osrnLJz@M&rxB(0*$fEc$n;NgS@G{$y9*fEp=BJ}lc5OdU2$ z+6n(C`*^Xfoc?ze>LZhkXk+x|Gbo{01^$78md7jylRX6B-wnE3xPg#k0YRP@2$-2m z4!wG-&SV+^;29G4pEUn=i?mTc*?h1aCHyS3YiictKIf&l`e&qd8WX3;CNVmKUB<p9j;fNFRS?pC-cBBv41AtMiJx=p{FmTMx~tl1+(} z`3qjwqw++u&{p>~*UX~LpypgWRODLV z0{b<>VUSzj&e>HYWX>2YMf1>TZ853T2Oec!c{Fdp;eaRkuBOolPb| zPIaWKYJS)b(mHGcS~H z*lsavhaBf%x46?M;`r_}unJ4CM9pXFszFx-x0H72a%Ih6Y>EkE> zpn0q!FQf1K{>Tj3_wB%&U>C>GCjC?0g1-1Rlfs{EZEbRH+{ec_eJd-B;BHD;$Xvu= z``zxcM@tL3fT^s9vVG@hW`MKLMBbf`owvQui;}9U>WRjShh$+%$>l@|Tu0YEln|_{ z%Jm6BNo5FGcb$c`MH&Od$Vj3bT3oV)zzRuAGo?XJ`urL5|HWil43N20i2ue>S%ouM z@YUjkk-A)jEFqX>FH)>}@2{Ed_N4t_(SV$*JufyshTD1pJoA{9Jz7$Z*j4;>ZR}kn z?2xo{GYjOTx;o4XxQ>B$C=qygn5zas$!G!D;B`hD90O!)OJW#Gfbs9_cm79ldP}XH zL#;=I6@ux&?a*cktw-<9;HTPZ)3)dN#yZo0Ex9=`_U^@aLSWT-$M~+4puzg%NZ{Eq zCb^Wg;h&`yc^LQF4TfikGS9RmWJeSmB$9%Xm?f{owTuPG`+;jd z6ZSUytGY)<1@`4*8>uZy^T~Af_umytyB{U6&m}(PKQun*m-|y5Ob1weT7OD2487py zLpoRQWZN$uEre zd5;DKT_8h%rBVf4YZkzc<5s{i67xVAQvm=bxBC%7HUTl7VAblBy_}h-{rc}`RDd^?h*DcpUaGwX0s?r-`dMB< zHRJ7idpIFj^B-D{CmDadMc(%2FO^ii9P{9>GgQe)JVSb=h1B#)g@Zap+Rt=-d-GqH z=yKAXW^KFU6{f6zl{kxpl!e<%kOnq?lV6A<3`gcOz9db1NDgZ_ocEiOw(G>ZMaBX- zG;vgr^#t90h1m(`-K&rRGUnYv()qU7_0{3NX)?LO4Mzi`lj5y zlXMD&TLYIBEiZDx%!|#^iLyJasOqk}sJL`4SN;dABMXGxcCm^+ak*fhuamU)QeaEr zqkG|opQ&{ZvL7&M}-C_0-rD;ClQDt>?+MNZU6XP z5EvOB%Of+Nl+(p(H(n&T^zG%;O~1?034bt(bFps7Bd^oOQaz6?(US#aJ%ersYLNf! z!26AQP0HpUVO%~>3iM5BI%^ISMDpdI2RW}Blc38?3ha~JG{a?5U1xBCWfC0M0l?&tse>1LN9Pj;BJ2l)CJL2;VA@IWgTmHAEr9vVfsc! zhiP2q3dYxq*x6HN&x*bMTE&FrGs7KjA&E*C8LoIUsCJaFl*cWKXtPTz}!Ofc`PCiU7Ig)0z3ZHB&>z%!Eto%#;Eewpj8Kk~Cea^~^Dj{G)7 zYB%QSG$@IA3}J@x$b7259nyH!a9U5~8hoE{BwxP3tTRu#(hR8gM@<`4kE`D|<0JCK zU&feO9bLdPe+7nL^W^&0&_gCEgEqfV@1~eFtnTKAjPLyBlRYsmVo*^)?3oLTg7XF! zCyWRmrHE%Y97vlz;P8z$kzFuqd4BkHFOuXyCX7PqkDVf^E+gLNa3lmI3>3U_d;aoY z2OAaVS$Ivcs4o4x+KsF4HSfAI_exFtUq)s5l+K4FW3N5ceSTXaVC7r`(U?)wd*er@ z%9oOiK~(KR!xC0&zc*6EsHF321>@79J!U*H+iV5T z{{8dU(|+)CqRDl)euKWql=AM0IMEt|#=Uw)C11P>stZ606)>E`S>{mNnrQNwB3qu+ zn)f2I1^EKYZvG#mlL)o>J_fgGI~F!bFeUUaJ7T8bfnaG-8^Cd#aqY$c8N(Q}yx*}$B z!}`BssxfN^k^y&G(BL4VwfN7EAbPU_2q=x()H~{AaX$ zO*-9ihj-z2HM?!|molrwMXv9C)JZfY4)FG~cXNVBl!Ru24+E5``bEw>Ae^Kfwx`Kq z-{ZAndj%k;8&hMVZr61CoF9G{Nnn!;>L`v%U4ux`@MP8TxMm9cOb!Gg5%MVDFod4^ ziKn{DIIjn%m_*bO_}?V}P*Dk?HE2LRmwi)cTqYNQ?Z5m(SY7rN0~g z*rZ%@+SKgL|I*Jz@99~X{fBp|mq)#B4e1ET1)St0?y=L?HS0e{C+BnR-;jyNF^wC` z4{)zNDXnXn*Mv1uieDw)AoqRa-@SR#Ruox%yC&lXL&jl4q+jpaWoI!fLR&idUEh_U zM$;g#*x^EZmNa2W!vn9C!^GHdf1cmD=oc(Lcf$!Y`zh~8MnWd)&9}|1E8AK?x%;Ca zGLOdt$f)|5?KwsFh`f!3fyG4sLK!`-J>;Iw>Yny(ndb6yp2%(ho)M*(w`r`5+#C;i z5*en!G$r?&4Xq5i^ev-~_tvVPSkSM?_!g=DV9_qR!*hc-{TBoNzDegWNvYc&mwSpz zOyQ+VSjuU;Qb!QxJvHi5VJ9;gPI{fkWGnZ;ejCw_u=&? z9lKfdyXD|q`Z*%$sCnF#?1;G!gKIZzhCW3_IauM`{CR0k^%DJqfsw1U}o}YI`_@R z1)lpuBWk}XdBm-~HE<#1JPz$}pCXx2(dD9Z^LiG0XRLrXwOsa(@g-~xUnY+yH>Z35 z@%>X&lBEW(k=1+wQBG+1IUo?g8B8D2E*63Pg1@L&{Lz~1k5WVr3$azgpPY_6`PH$D zFI^6q{hUkaXm}cs^Ti<6HL_q_z=x|Q#5YF?-}ZuzGZ{FcthNW90w3?!tQJKwR|xmS z?=WMVezoQ8-1pz<%~7KnPWVcu+zg07fimB|z57j;^zv`@9UT*%ZbY|@gTZqA9`bL& z12S7A*e8-siQ%L*04C`SB~DJ~hX_o!tg$7Z(t8b)p@((KoDoXGXqEMXzc+4-bcD7? z<-~DQIXpiU=%7kZU`B zOk*Tl#f8(4ubEedj2;?!1VLgq=eYeuZ2TML)@?XCIs-D~`dkExaS}5xEUzGzg!|E=|OQ3;*B9P5U>4aHcc~pZDb#i_7z4hCe^QVbG+TB2rLI zOSzGy)c>&}ZFOaV<8v$s^}nWOdL|ui|F~<(X4sH6hy)G26*Hur^d~)~9NUh`q@iWB zsZ0iompCSscf|KHKqfzPyl>8(E}3AVp3xs+Dz?&|qGuW}m?rwJ{5bnG$MB5Yb3QO( zNYkC^Fh=3_L(F0eiOsZ~u-PuzrW0KNNpX+Xf!C_P=zHT~{EiF~zg7G5`&u2OflCU& za;t{1?v(B}BJ54Gp@Y~B6-%K&OwxFIGZYPdlzt`$T9fT8Y43pic`~B_8va;XZ**v+VeSJqgkxhzB4O=WO&2abr$`obtNu%YMy zwN?X#z5ldi+8XwMhvVe54zCqU?9SyZx9FVv{*_Z8g`mG-N*VHgN7ep0<~-x5Ev8?N zEYo#agH`uyI!dwMk|qLJz;_8Juhg_6CG+F_&{`d`;7+Uf(Aw!Er*&9F2vB@xfR(=Gh_wHzN4u+tYHc|nj(X^bk11Q@fHhTIQv6Ouuvm3^Wq}Ma$)9pGx3#kD@0<{5T@LysD zn-+g&gRF>UC|b=sTwlJOASB%)aQLE?$-rLxu78Nu#3Ei0_RQK(Di42_V19V7#~|Bk zQTzFPoX2C5r(_~6(vHj6W$rg%{PArzgFf1`)!4lTcx)F7Bp3g@@WkzB4}+sU$@IEC z#a5>Qndpac$3kXjc#tVW!=YL*b1TjBfP$2;3@P_jJt!K4-hn7AB=aPc8<{rxi%)c; zX}8taZ@B)i%HvO_ncIA-?POdcdANC zt12EXWKIG)nA5E5XqFWq`(5)dfd#+I94G)w4_iXr4yF=rSIMdorEVpPV%xcH-Hp`K zahsUOtKf$In>FR}t5WS^P|=)90IUn|QE+MHn`t^Z)@_CR@na?PdFQ04uhdp7D?Xa1pd8HN zS_A==9Q6uUPtc@ljZt;^pH80y4{zaD<9bERVeSBm*S{gua>r&Q{T&iz7&11Xjie`R z7XQL{{gbRk$i%xC)%0q0&P6ZcHj3^zaMV%YLwc73T?Jn672CLgD8V{&wI{Us%A7i*xY@=!aphiOY%JtGga#C*C^h4tgQCP zg(vr_oBZbQL~huK{o#WQRwlvLfQ$-2UNaCR2cm2=m?<)W^>H0`k)^TEKU$%u1Z4J- zNMM1n51*X?*IK(0o7W(W)ja_>()#jEpBBuppE{#9AAk+;z^>%Wc^k2i4QDke&jO&@ z5zFo^`5gA{te~8UoQm(9TT_ip*6Q8K-Kp6K6lJ})ho&{AZ{H1?jMlIsx!-L7ixx^=HPA*ooT^Kkn< z`$4*VAuR{&ROqSK>%$Ag(TMdEih;gw>(x+abvRE9NRBO!+ehK?+2t@})+2^gPOMK| zqkmMqTWd$`OkALRnoRzN*M*od#23mB8&hhZxx8ba&lw$qudvQa7p;As4y-EfE@sv< zn2ou`AvF)v`t7=B7yv)Gyc=OX<~}Ro(mC0yQTZ$o>3MxgCZM`X@Vxwuwf<|c(6I;` zHoA0USXb8u0qU`;&kheJh)0r2ox4HW0lBsYsg%7JiRX_X07a^pc~$oB)`b!^ZkIs{&kWzN-a753zuiNxIRX=p^+l-Vcn`oq#1m9=96ijqbT;M*8RZROkmzoAy@*}Bm zcLPF6_zkMEc8D@=;+tu>*R6`4t?u&DNb<{3i&ymNYVf|=%02e>WrtQ!n>q6MF@>$ki8(2e|T5*WeJ zLt+@{#uA&IbJwhhNeR0a0h^NTxOswQm<6Dnf`W?xFrT`n$8ac&Wdh4waoXJ%7k#L< zgHpk6E?+vgay8OkS95(jlfo+mmo29 ztkg+m#O?5!Vkh#eutQ(ei;mN!J@4R)_Pfdi_u{;#qt5s*OwVd&OQKQ?xn=!?vFlL? zk$&s2gt*XsWIdziO#YNgLACJxy8DOb7 zkLG=?3>PIk>zuHRE`4v&0Ru||`5&eHG<@>dLRj)!{Nq20`{M7lyPuNP-JV7!*TbUF zB3U45vPq>H*{Z@461iIJMiGQy;;(L-{VCDW(HD-4@t+U|iUfimeUXzQRYI+Gbk_~D zUuQ5PCpG^6$N!t%0G^+nxk(Z+Kmvn;hSt`k;38iqX_1qV9d3?S8iJqh0lPV&7w6k* zBD>q$+mK9MhZKYXgv-l)y@Q2%W(ohx!QgQ+u*c45p`g?3d~&*M=Brg`9dgkLnC6oy zLpS&LRe~O_8C_jnAH2R!zM4w`t#f3(JAD7~?!x^Q9o^`So~gAS0J*-yuzj^@443Io z>8Jav)HT5P(NgpN?b)V)|D{8jSi0QR4Nc50YLPw?4Y3q4W|C-(Avr>DtK7pQs4M z*KhsY{wk?7Cc4kt3|{HTjfbmrBaR=UHpwvTVx_0|qRDjzQGwN-O;~;5XXiA6Gikj- zZDj6L>KIgz_wNPX;G`{t*o3HpoukSB24mjb_uqU#F&X+D1!7<0-#zs#KY@_ry;?G>?jW{ICDMhqSA z5)Of`7h`oMhqLsK6a%|hR|il-;Z5zQAHgrUC%0#$aFIifJ6Fa-V08D{>Ak_hQYCMr zlTNO&K$U#%;%w;`f`BQxqMyC8D-MiEQAIt6=RY!#J5-lX@$$P!4JyvM3XKC_w=cZqY z5>1VRbP}HVlQXxnF9iO>VviZO@7O*GV?MoME(HlYX@|~8sOSxNr*5{)+>B#$GcE+2Kdo07M14GC&Cnz!stNi}Oh{&DTIkh~<(aR?%2u*Vr!h6(9TU+? zs+WmWno*R-KqP-QMQTzRIzYz>uj?4nr1pI$Ztbm&{KAJB(l=PikUGCs4w)v4CSqAstY=F+i0(^N_I@)9NRu);e z-K1)y&x5^PaWzP%RL;MMHi?;Si;p&lPgJJfhwhy2o(~4_PQG7Vk$Y$!F9@&T>D*R#mN=EYn^`XJze47JV27Amp|0!y z(2wBn8VyV#9^K<667!zV_nh-UC9=}E+{9heN!UwaU9GNlcxa6^-6(S;M`X8(I3<)g zNWMIp!!OsbUEBZDi|D!iTU9lML$WV{Dl*tcBZJ+<*nenUq&w`Bk48;=X%BWi{9qeZCTKq-;+w!~6%g^UBu zzpU^`l8Fe+n@DJ>;M5*_M1) zSS^_3ksi~m>!-{#!cSV~WLM`9L{OKWfS)Hxi_bBkkWnD7SFZv0vZ_zbHZalX#|Pd! zvs!1M=udm9m7Thxq&T8snd*G1tyyt?h5Ku#DOH{mK*Zu7)f(>=#+$;g(Q9axQT;$; zoUG=h{3`}ChbS?u4-SaGf$Da@Pex;-c1-x}@x0di0bOMY`>M@`;j=%CbdyYksvZs2 z1uFD1O0^D|pT{#ddeBebUr5ydQ-`NzHy?|P5FO_FS*=3KRP$mKu|I10EL?_Z2R$#? z>sKqP1P|4pcPG{11^4e3$P{4Hxnr6ZdP;T%t0IX@AToNz{j#D1&XdhB60AQi4C58e zL=)o+Ff_bLdN)d~sONkw2B0^NFgbgp;2vuiws4ICWoU1FXnnF%o^$NSb)hXXo=Z>ZWv8(mgG@azNKH=ScQ56ZTGu20{CF{8xY@p3ctHhhVG;(}|rQ?N+)E3u$&Vjx;|{2y`wv6T!6E zY_&xOPx2&!DNS#7GsNFrH~cN>m4YB?qEbMPYfhl;2BX4A9~ai-93D@T&I9<#zCV z!xTfG6x@;Z^_$tn!L%Ow`maRF)L}T5T&pvc0oyxi=w`3v${y_zKgaotR*c+n3kC`*kDSKnNZ%jgu3(JXrY5 zpfR8slRR#T22LL-_E>*YX~L(%iC=WEs#m*|Xlmp1V0419B8m55mBw+#11HxFWjmG*=;Lf>wy`g>EYJz#i3S0$J=3&x_6J` z*r^H*pPz`6zw^N?sM6T4qG6m4DXxRNp;EL{4=-;a2ciYdRDk0JbvMPY5S- zLMUSpmj8wb?fIP+OXi!qG)L=-!rA{Ff1I?=+V;Vqav?L~1>FfS zXHvg_m55C3CGb!L?C+6e3ps!Ti|Z_Qt)$P(k3jtf78eim{#gJP@Wf6?eV7}ErRZ^g zNJ)4y7Hzi(8rAfvH|cWfgBJPA1}rNq!T#TNjYhdN`A-tWEIXD@U$}>(?X~T;Vy%OT z?KJ)(xqU@-oXJ9egAB!qF;0(9p+3j>kVbuSXg{*VkxvgF_3;Qa=rIKO|K)+kS}(`H zAbe|Hwl9qg7_}dLFY&w5+HAASS_IPNx5`m}@c_FDec2JI9|tg8w9<=RvjuzNV8ZWU zptMiH&O5U{$GR;O^O*-I-TRgAf4tSgodHk8auN@cGE#Zi23jeB#R0> zpE74_CV6bi&&1f22bvgRyo4)=ddBP!V4$pA0$xraZT2X5RdSNCG zVj2Kxr=frua}GZpO2z-nbD_+7l+x#|`oN;v*F_oi}^S@&?a&Fg|-jYw8Y zJS+>f8ZNOkw?FLl;waf$e}$*K80}kW<}SPU92Y=RTycGKwXQFYRetNF&ffl^7t$B6Xq;1!V<2eDMT_=?kFqZ=s|XR z|N2%W-T*}X)ZcvugJC9~ils6%_vGY=ZkA}dDG+Ly=_5$igl8v=w^*I*vszlc>8=2e zLD*UOF7CCOKD6Y+8{O-c0)UhInhZt!co|B4h>~1cb1PMeE%@c3M%Z?8T9(L8DBPzxUg2w}D5STB&G{T*TFWX!A zb<3I>6L?9%FIb<%GKE#Q_=~;(LGnRza+bslJ4-YyL~iCxkbja%`JlMK_caa8BE;2)oVSKE zo%Q2rb@$Wp1Bv$No{lXag@AtC7P?J$iZso7z|DmpERbk;tC+V}Y+73rlM-`X%urEP z3}rQy2s~>^e0pdMU5$I_f%vnR1QPI1N!xwhAMGu*ZC= z&mSG-4Nx83-4>~KXQyj|D{T|XTrR!~IO%d5nR9AKgu3&*`Zs6BK9G4$9UC6PFH4d2 zO9Y9UR)lBll){={CbVC~g`Y#|cn&5wXBhe)F!g_mNi|RBu`YZhUs!9>o(Keoa!_43 zdtVsWX+BLgoQzFASR`e11vi+tM&J|d1ai)E_=!b4WzVGMo?u$`7qAiIrqxpJ9yZ-Q zUb|bWHh5H`8eWW$*!)k9~N1DVdFTC z<{>GG?t2Lw(7ffEkO|r=&nb_g-Ap~0fY)}f+p zH2_fqwk+;UR5Gh67X+g_@A-1n1$iRL)%!ikZeb0+W|MLKUWp8PCQPM5<4qaKieG1Cz6ChK~Y8(8*>|3^N?Xc9UqT7Yv z6T+Giioo134+H~{48ku1{xPX7Sy_tzapiQpw2W==2(fNL_9WIVIcVR556&b zQ+=#a;?>{DX;^_rA=)7ewbRZT0D@zX;Icb-afZCb!*Ng5ccM*XccNClMq=3ebae|C zJ^n`7d>0l|s4<8Z_|Ualm}&E9TSfBTFbqto^dlrp#x%E*vby zfpjahE+6|C?v~V;%8V)IjmPpHT(%tVwGW|;7gUV9Y;-8a9_d&1pXv8vt_VoH`dyRe z@`}EkKTM*CT_8Dg0dzMgiW((a{LGPQO3ksSh_* zAMvrnE^BT^zsdA3j-m~SwEEKU_tb6|x$x)_y4edG#aoG`ChUKn{!{N{ zqvIohW0)?$vc5Hu8jCF0E`Std#BV8#0#gdndQqEQRV0hwn&`KXV93I=C7W-HM2fO{ zlOyFM4C|a%0o1GLB+pA=Kdk0#wBvtCL__8t$DCRQC+{zc&1q#c1jjz4A{9znVMlm% zxZx~bZ5*v2@PRhPSRk_Tl+bWxOXuv+H8hzs6z%y5C}J+Z3Aq?#5~HZm6A35SdW`?E zKgrc=L5WqDrVZ5m!T1YRy|!?biKaml5`7?wEXD3%H}{IHo+DG9YH-P>hM(eDHMPFh;xSPr!fcO$a@u%m4NY%N!wIU`-^E4>Y zvGNnc23iWBdNi(fK#Yu?J_(1TS`)-aQ4egVhDWatX=&w)U`15SFb(`bjCztz3Il)w zSq%o;^&JoZq!ZwqpGL|KX592q=;#IYD;g_frXCgS*VH*I>UYInlT?n~nTt1#b`nJV z#}m5sd%Y=l!q(pKj_zQ$N&-PJO$5+&=;Yr>4+NZ^lQFOO4OGt)Lvv$80A9e3#nav> z+2{CAG#yM=BMYI3dYnGj4MJs4pUnmz+?|rIGP>-rr>`dbulv?$A$X5oYDy~hD5J%< zUC$idxd8N8wd2{&JPxIJ`x-O!2C6&-?gI7HQvUNKSHz@^&A2 zcDb3WstomCgDm}e>#I7q5Gt5_MiHjysZdosLZ~Ueb6kA&D{B}K-=b5sziBg(qgo0O zw-5QI79gkH1`$1Vs+nUn*cl)n2Pdy!j6;kzd=aHnl_&x>XlgEQa`yX9{JNR2k#H>( zrT1Qu+LAIne#u%>uARfj@nC%8<7UL&0MBykh0TXg&X8-!Bs4--z%6a2GCvhqiS$Dfmn;h4VBK*Byo0~w`eX}k>yredO1 z7g^wsq2Pyo{i&qUIB8RL8`Y>b5)FlHSW~VFzVzNp?7U*Mk3Xq9vnJxLty(uo#i*dP zqu~BU#Bq4{RW&$w7eIq~*z={I_>-;LR{E1X#c!c&ELp0_f{oq~)-0k9TgaB1hQWrV zMIvb4AgPv?h#co%5dhD+3>lVhI8HJ*Ra@J;vpM*eq3No)c29zw&4FXoN>CW!r2V4- zf%!MgM8I>`YS;cjGAj^gWysJRi-|ZQpeJU$ZUhL@_sskod>K|4K-6U`VmxaAONqs{ zo^nFc+bdZL6JF_-;hB8{Zwe2;Z?J(1eceD@RDG3=U)5{pNHwLkdVOQ2B-xAWn}5?9 z!plg`WLVs4Yt;Lvl-{s|dXvrc+Ki%Xm%V2S#^)R`Zqt$AxvSKM2fTO_ett2I}r!pJ*P;`Zq=ISw~MG=;LPLsQw6 zF{u!LOvd)#LpUqeAfkjLOfNzs78tx7Z$|#l8k}>CgW^Lnj(XTt$d80Mu7)~10Alct zePMsCH2=sdzm|F#F34mL7L>9_V{aeE9^9xptan_+?W+cHLy%wAfaGHICDe|vguu}{ zX|uHS9ti^4nmPP}rJDhwO^

T z-Cz(E^`~#KG(n9axxH)XO&2`I36Zg|vlMmk^ynXIHv*tD{Em3aERE1DF4vMKrSVLs8@HH*%Yn=XNahZHZ}W!4Q>Dt;lI?`>!Z z{rVeA5mKYIR{yY^-m2d}h_Qu{?iXi;Wd|K;tBG=$6=zaN6vq-B5I4>6Fg)W%t`1i7 z;On(OR51;DBos@&>tk~?#x8!@CmnBjb88cMPkZb2i?|0>-T*x3ST|BMfhKn)5Si=; z8xtQc>g)N?oebMDnhH}Cn0kPaMrv%RJq{;vVel^|UaN954 zWog-rX;A-^zt>f)!*$M932GB8%*&;i&oo|a@-D)Un%JkwBJ0k-3z?7)3J@Tou^^_m z3o@`CPm^rnb&WK!{1s-YskwPiv_BQj;3$ zd=lCZpN+GI65)?rfz6zCBU6GC<{x-T@nOc`5Fs`LmS_OTj!^c~2Yoi#KkURQeJi|% zMqf>3RH!Eqww`mSo2nXQ`W_4UzS5dyt=W7>Bo5C9*YVq`XQ~(RoxYF3 zjUW^8(B+oZW)^B@(_zT`jX9hNmI1lfA8prrjUhdwoc$dAYSW4Wk48iNB6}2~+LK>% z6}Xpl5+S%=1F`(5F<#;L^)0`~8#Jnn(DeOYPnl9KdYY8S31z2@&r&t`aUNMF%e(>w zJ|1Me(JeP`$f3w?zYDw@Vl;3QKKLAIPvC&$8E8s={VbDsBBYPwRh`%wHq!KXz6|b{ zrot2=mWtN;JJJ)suH?)VDL&=y<6x#2fl%25lYI<~o)!Eig6Tk0|Hc^+rzjn$3`I01 z=9Fl_Y{E_{$7urU{xgOqeybtDJ&u-s7?$30bR^}O8^CO1ThLJ2tW?fzm9 zDf^yrn)G70#^5ihYDwjL)Qr_Gy6Edbp4!jUia>zdZNlp=26wDTPDHCF*Xj`Bywe?C$Eso7q?;{Rc5AG+J@@~}apg%Xb% zPFec_4JZr7fqM}W{0?)Ea)F1UNYm{2qwti9pRC`6+sVJbCMyo44g=#aMM=0l-%1oa zFq}!MN}}Bjv|YIUU`pSn#B+{wQK}JJ&Ug6`Tb~5ggA%A+t1|G@ zQG+!`&w!*zKTY$0&H=`eV+sI6&>ti#R-oCM}E7b59o`$RZC4+DH zLyh8dK?r{ax`r+Bz<8>j2&8acbrSVQSJrH=x7Xs5eG<2A;2>)$C`+vPM8RozwkD$jNpjqezsB}>i z-HSvoQScXQvqIhaZKK`~Y{waAGiR z{pdA^@Z0x=@{jjm5FpPWf64^r@G-zGIT2>mkv@G?&f^J1IpZB=R9_4`N*3T*jo`VT zs{mSdTJm11+G5oB{TZ9Yn5Y?@0hakl*r1Mo>pnWVky)qlMc*7CF619cP67Z>WvT$J z{E%06kpuuP3U!8uwnQi&Z7&bn^tP#Hy|X=lg80D}S?CZOgcZjrLj=%LEkDFgJrEJ3 z&JCRn7&j(v8yGr}b3lV_I9XDCw){?+=}4)tAYAXm6oC;fxDGY4u=IB^jvba7p{{Oi zxST#$83|GbL>%i>kU=9khtu!FgW?|rEKPB1U+~=mFaG*FK3ROm+Dwi<^}TGRw27vi zL9occ-d|D$y}OmY{q__fNFpP2v-oeRvDW#6n4QhL4^uxsM~uqB2VW|jr&8DhjI+R1 z%tYKdYB?SWBD4(DxyX3Lv0FO^#3Q=h?J3B7qg7U~Y%C9876-&LJYCOftDc?KpL`%o zKQ9$dEu4rvm!Ch9BA5u?V1m@hA*?Dvdu>rat-E%=0EC?kL;bxzDDN8SS!D@H1F;b+ zqw0zh_JYFgNJkb=Fr08M>8$Mb=QnZ(!!>H$881|Q_yTi2E4WW>)tVi_yDt$I1e0XM z+%3e6kR9d=nkpKzhA%&sgll5o#`$u~$2sjvD_=PIkWX-0QVI>6TJm_oDRU9V#n|fs zmgE9~bXYM(e~t%x5~ye$kkE*70?_c9l$0|0B=`eq9FSqo5G#E0WV`?ezWL30Z@IwF zB=dwm%_Y&hQ}IzB^#QnZmWpFJ^J-Nx?3PJgX(^G2h*|S~8XzYWjMg4r=AAspp82(n z8c#dWlGw=@Ebd60G#Kl#;Z;hbd_sNP9VgDCb>#HxOr*lkLMwbE3i~_G(lGjJi)G#D zxjBjg1XKYzn7`czIwNKsYZo@4H0cM!Qc3y2*Is^ z^#Rv*)B~(VpT0*+-IMy#%apo8-iuwQYQsZWCsddTAsdJ|v^VVPv|q^a?Tu@$rj-HG zwrTj(-yw^fu>rh*<)_3aj*aK0izLn*KaRrZrf%t+%(VgAs(77@-l)^;n@#IYDpMsY zyPp!zue&EK=y_7Ao(1Ll;`GyP;n;MQAxp2-Cf6lE{H5upS8xm0n@|8GcC9K>8d@ik zF%#?@uo_5G5*jc{C6zWZI`YrKO}wAqm)Zg6qWOzd@JNd_-vtpRk73sazOVyNnJpsc z3%}LDTbmJ@OP?X{2U2*80|SVqJ|xBHsh79fyq@p@8x;{Vo9Ig<=E^#3zr}we1|WSE z^!K`d$|snMgBKAkd;oy?sL27FuksW0540DS_*rYFZak&$>csLahyvjf{a<=Dp8x36 zx<)x{VAyvuhY~L%2`;D#Kfu}P>oqv!sfD?Q#Rc{K^S1{u2c=QXz8kESGm6v|ED>BL zWL&T$t`VgP0KBfHG4I@|QgHDi80zchLnG~7mYu;b;H z|IyBQ1~s{*aXbPML?Lh~m);a9UWx<-LN9`V30(z25N_y!P^A|s7OIFMp-54p(nR7F zqd@2;KoUR&0SQP8(rZj;``+FCurs?mvpcgN_v88SoDa`?&YXAV`Tu^B^bZ@m8$IPY z91_nU&(3;v4{k5EL12!on|~&D#S9blqN5biHCEgb8F_3yqfZG9ic07j`2UyxQ8z#( zdeM2R?ky6NuC~Ak2_+sK9SzpT3c`Rsg52=ocUfBtrU?lNJ7HwP*SEpa@W<-11bwS= z3!uN$Vx2JG>F%r@FN;Z|-Yxn-3h=qSTI&SyKrxqo^oAK>9Z2}e8wNNkFwsd1aaZ=a-*weh&DA zef%a;{y32NzO}VQ^s*{YRe|3KgA*}=6+i}c9Xn=oeP|g=l5W5GggN-7*@9j%pJFte zPfKwYmeym02oLd<^MT37gAMWMc_FtC)qCY}+vN~9{HoC3Qs?w_>eyfX9unG$YtLdy zq;g7bmw1jKAnaJ}`n#p$AEq$>r1Vhe9LZ3nC3Ym&w@xMESE6WRBg~21xtp__bpdTt z0XcDWQpa#ISyyAQb!ViYPH&Ajyq_BwXv&@k69ttiK4w~vCBaLhC*ASn_Q}1au@Rkq z2>%zx4X89%ZCI)H0pE=l^gMW3FEpFCWScAh52AKr_vSHFgsvsvWSI3D-rT(k>~^Gu zq~gm+Fjcb79?w_x5}win-=ik+3()tLedM^LZ1(+zS_B+lx*|!B&Oju!yG7B zp>*x>CkmJSyvdIoHt5t^HFl~RHB7dxl(OebrCa9N5MMu}=l67gj9?+#QbK2(5Oq0E z-WjdjL`6aiFJ&=8TWIceQ->o9iDUH<%eq%rAjN;UC*k@-Cnq}Q^_5Ig?>Lsb zpmfw$2@LU7BeuH5xDOtmy@{2ywGh`tZixco55f!k#?mzrv>)l)``@59aOt-AE~Y24XTPM3ODK2+Ygqxg(lkuOXA;hF6I$eZQRvjupj@ z-8Q!=G*zhApWHOS#Z3ocG;S)P48t)oJQBu$l|-n`?d_sEfIVYiJ4+T}kLm;(iC>p) zp|t{r?i=hbL0P~AMA@|h#|7<^-=kc4E)svZ582?=XF)2*>Mw#``S%+9IOm}qv7(Edyi~^HyQ?YCkk3AT1cX(g(NUOx0uzDnSuH@bNW7JMh`0XQ~G3RV;31|DEJ}T(h zwg6Z!#hX}~(Zbxh=JE#Bit7UuDp(DnYTKV&{EBc^-&VgCZE%&zeO9E@L0*`2hrhMT zUEq}9D;4L~6!S~P#_oxmJkp0|$9|Eoq8J_)lU@t6|CHF|Kx1$R4ufZi-pHS8Q+}HV z0yX@5e$XNeRrn#(tt0KY!rVf9!f;*u(Sa-OY4hoFya8RQd692mJRG)*k&+dFALe4r z-9&41zn0R~Zxa?tRtZIY`hedJ~mZ z96+t%(Vt_9YIxt1uqkeS{6h<+(mG|)6>vY6c2H^p3jTtg^d!{yMA@Vn>gWnvqiptiImdzm!k!Y>_6~U%iHaBstaNq zlNnckfHp9$IF+vxi(siZbL*rRe!dg`8`sWG3U-0F$o_bsw@ojv+l9p=iknx0^1{@m z;fCNqfQLk`-n9=k#rEEe3iy$>8pQ(OOT-DB}1Ty(a{J`08m??8YN7jtKm zy`Z0|0h*`-ZvJr>9?ZQ+g8mw067sleR7%(?l4qTe zk}vx}G%<*n(9ju9f$M-2E*EQiny|f6N8aa6f)0F9B+x|lBY1gy!b00?F5tZ1i>tUc zIsvC|>jlVXcwo~LBc$St2#m+I@Lksu7UAOQQyAwIPdI#0{k@u?XBrTwq1g+F1Qf4K zu(1)Qrm;h%)4w8J*#Xv!B~#*3ER`yW(2ALeY&Z>|qn?uQ&<-&ppgBdmN{2x~CBOPz z`SrJGwNeY%+PS`)Asl3&xYN)&SNraGQGP#Vb4j3Wxa}-VDwh++uYv5B^%jR+SX)ul z>Qe#nG7x|^Vkpr(G4jwlyLuSaiCM&fB=m1b`S`xxaPp^)mfZGL+C_2~^T|}RV?V(8 z22xUmCELM1Qr^5qDGlOVr;W?d(h}Rg1722Y61e+bF`$WxH|b#~(oGM9inKg_!qF=O zjYrGPV$k@hxJu4o$S%|g<$V&L0M0L#!d-p3+JgOPS9vdm*P-V6gGa@*UQjf<_IAosxzWLCwYP%u)tp}Kjo=nSz={-4*BqA3yy#Kt` zJt5t?aH&&Pd5$sQBcxxDXEZ$GnJtJ7-gSY~RjY|V?O-C*@>c-L9@875gmgIIB*nw|055GWnjQz;WO_~DuS6Fr|}@ZX1! zdzs)PQQhqizeAIJ)UKaAAk&o^!HQBgab2GZJ1M2j%ul7~HAVZH0Jt-wtTNDe`@35_ z9WE{+0^G$%02|WryD64kljI~U%&2C}N>aaNWgqo@6Hz07G_zPyQUs9PR8>Q$Y0L6V z&pvd>odf(i_^1;l*RbP_*QS}m9^O&orgjGe3JQ0 zjrxo&H;Z*%aUMC#sa*2a84d4-k{OuH72oj`pLOx7vKg--SL27^6a9)5L+DW>Ssa;3 ziqBy)$O+G4s1No;OXx;oy)MncM~_q<^GGioE_1WP$Rt{RPy1wv7M) literal 0 HcmV?d00001 diff --git a/model-examples/reasoning-models/mini-r1-zero/tutorials-docs/train_a_mini_r1_zero_from_scratch.md b/model-examples/reasoning-models/mini-r1-zero/tutorials-docs/train_a_mini_r1_zero_from_scratch.md new file mode 100644 index 0000000..5f989a2 --- /dev/null +++ b/model-examples/reasoning-models/mini-r1-zero/tutorials-docs/train_a_mini_r1_zero_from_scratch.md @@ -0,0 +1,458 @@ +# 从零复现一个小型的 R1-Zero + + +## 目录 + +- [DeepSeek R1 Zero 介绍]() +- [RL方法 - GRPO]() +- [奖励建模 (Reward Modeling)]() +- [训练模版提示 (Training Template)]() +- [创建数据集]() +- [在一个简单任务中训练 - Countdown Task]() +- [在你的环境运行](../README.md) + + +## 1. DeepSeek R1 Zero 介绍 + +**DeepSeek R1 Zero** (后简称 R1-Zero) 是通过纯强化学习进行训练的,使用了 **GRPO, Group Relative Policy Optimization** 强化学习算法, +设计了带有特定思考格式的训练模版提示, 并应用了 **准确率奖励(Accuracy rewards)** 以及 **格式奖励(Format rewards)** 进行训练。 + +![](imgs/r1_zero_process.png) + + +## 2. RL方法 - GRPO +**GRPO** 算法中使用了 **组相对奖励(Group Relative Rewards)** 替换 **PPO** 算法中的 **Value Model** 和 **GAE** 用于 **优势函数计算(Advantage)**,节省显存并使得算法更加简洁。 + +算法结构图如下: + +![](imgs/grpo.png) + +优化过程如下: + +![](imgs/grpo_2.png) + +MindSpore代码实现如下: + +[源码位置](../src/grpo.py) + +```python +import numpy as np +from typing import Callable, Optional, List +from transformers import PreTrainedTokenizer, GenerationConfig + +import mindspore as ms +from mindspore import nn, ops, Tensor + +class GRPO(nn.Cell): + def __init__( + self, + policy_model: Optional[nn.Cell], + reference_model: Optional[nn.Cell], + reward_funcs: List[Callable], + tokenizer: PreTrainedTokenizer, + args + ): + super(GRPO, self).__init__() + + self.policy_model = policy_model + self.reference_model = reference_model + self.reward_funcs = reward_funcs + self.tokenizer = tokenizer + + # Training arguments + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper + self.num_generations = args.num_generations # = G in the GRPO paper + self.beta = args.beta + + self.generation_config = GenerationConfig( + max_new_tokens=args.max_completion_length, + do_sample=True, + temperature=args.temperature, + num_return_sequences=args.num_generations, + pad_token_id=tokenizer.pad_token_id, + ) + + def get_completion_and_reward(self, batch): + prompts, nums, targets, prompt_ids, attention_mask = \ + batch["prompts"], batch["nums"], batch["targets"], batch["prompt_ids"], batch["attention_mask"] + prompts = np.array(prompts).tolist() + nums = [np.array(b).tolist() for b in batch["nums"]] + + # FIXME: unpad + assert prompt_ids.shape[0] == 1, "not support bs>1 when generate task" + prompt_ids = prompt_ids[:, :attention_mask.sum()] + attention_mask = attention_mask[:, :attention_mask.sum()] + + completion_ids = self.policy_model.generate( + input_ids=Tensor(prompt_ids, ms.int32), + attention_mask=Tensor(attention_mask, ms.bool_), + generation_config=self.generation_config, + max_new_tokens=self.max_completion_length, + use_cache=False, + ) + completion_ids = completion_ids.asnumpy() + prompt_completion_ids = np.concatenate([prompt_ids.repeat(self.num_generations, axis=0), completion_ids], axis=-1) + num_logits_to_keep = completion_ids.shape[1] + + # Mask everything after the first EOS token + is_eos = np.array(completion_ids == self.tokenizer.eos_token_id) + eos_idx = np.full((is_eos.shape[0],), is_eos.shape[1], dtype=np.int) + eos_idx[is_eos.any(axis=1)] = is_eos.astype(np.int).argmax(axis=1)[is_eos.any(axis=1)] + sequence_indices = np.arange(is_eos.shape[1])[None].repeat(is_eos.shape[0], axis=0) + completion_mask = (sequence_indices <= eos_idx[:, None]).astype(np.int) + + # get reward + # decode the generated completions + completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True) + prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] + + rewards_per_func = np.zeros((len(prompts), len(self.reward_funcs)), dtype=np.float32) + for i, reward_func in enumerate(self.reward_funcs): + rewards_per_func[:, i] = reward_func(completions=completions, nums=nums, targets=targets, prompt=prompts) # Shape (B*G,) + rewards = rewards_per_func.sum(axis=1) + + return prompt_completion_ids, num_logits_to_keep, completion_mask, rewards + + def compute_loss( + self, + prompt_completion_ids: Tensor, + num_logits_to_keep: int, + completion_mask: Tensor, + rewards: Tensor + ) -> Tensor: + + # 1. compute grpo reward and advantages + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(axis=1) + std_grouped_rewards = rewards.view(-1, self.num_generations).std(axis=1) + + # normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + # 2. compute kl divergence + per_token_logps = self.get_log_probabilities( + prompt_completion_ids, + self.policy_model(prompt_completion_ids)[0], + num_logits_to_keep, + ) + ref_per_token_logps = self.get_log_probabilities( + prompt_completion_ids, + self.reference_model(prompt_completion_ids)[0], + num_logits_to_keep, + ) + kl_divergence = ops.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + + # x - x.detach() allows for preserving gradients from x + per_token_loss = ops.exp(per_token_logps - ops.stop_gradient(per_token_logps)) * advantages.unsqueeze(1) + per_token_loss = -(per_token_loss - self.beta * kl_divergence) + loss = ((per_token_loss * completion_mask).sum(axis=1) / completion_mask.sum(axis=1)).mean() + + return loss + + def get_log_probabilities(self, input_ids: Tensor, logits: Tensor, num_logits_to_keep: int) -> Tensor: + # logits: (B, L, V), prompt_completion_logits -> completion_logits + logits = logits[:, -(num_logits_to_keep+1):-1, :] + + # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak. + per_token_logps = () + + for i in range(logits.shape[0]): + logits_row, input_ids_row = logits[i], input_ids[i, -num_logits_to_keep:] + + log_probs = ops.log_softmax(logits_row, axis=-1) + token_log_prob = ops.gather_elements(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) + per_token_logps += (token_log_prob,) + + return ops.stack(per_token_logps) +``` + + +## 3. 奖励建模 (Reward Modeling) + +**R1-Zero** 中使用了 **格式奖励(Format rewards)** 以及 **准确率奖励(Accuracy rewards)** 进行训练,代码如下: + +[源码位置](../src/rewards.py) + +**格式奖励(Format rewards):** + +```python +import re +import numpy as np + +def format_reward(completions: list[str], *args, **kwargs) -> np.ndarray: + """Reward function that checks if the completion has a specific format.""" + pattern = r"^.*?\s*.*?$" + # add synthetic as its already part of the prompt and prefilled for the assistant to more easily match the regex + matches = [re.match(pattern, "" + content, re.DOTALL | re.MULTILINE) for content in completions] + return np.array([1.0 if match else 0.0 for match in matches]) +``` + +**准确率奖励(Accuracy rewards):** + +> 这里以 **Countdown Game** 作为例子 + +```python +import re +import numpy as np + +def countdown_game_accuracy_reward(completions: list[str], nums: int, targets: int, *args, **kwargs) -> np.ndarray: + """ + For Countdown Game, evaluates completions based on: Mathematical correctness of the answer + + Args: + completions (list[str]): Generated outputs + targets (int): Expected answers + nums (int): Available numbers + + Returns: + list[float]: Reward scores + """ + rewards = [] + for completion, gt, numbers in zip(completions, targets, nums): + try: + # Check if the format is correct + match = re.search(r"(.*?)<\/answer>", completion) + if match is None: + rewards.append(0.0) + continue + # Extract the "answer" part from the completion + equation = match.group(1).strip() + # Extract all numbers from the equation + used_numbers = [int(n) for n in re.findall(r'\d+', equation)] + + # Check if all numbers are used exactly once + if sorted(used_numbers) != sorted(numbers): + rewards.append(0.0) + continue + # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace + allowed_pattern = r'^[\d+\-*/().\s]+$' + if not re.match(allowed_pattern, equation): + rewards.append(0.0) + continue + + # Evaluate the equation with restricted globals and locals + result = eval(equation, {"__builti'ns__": None}, {}) + # Check if the equation is correct and matches the ground truth + if abs(float(result) - float(gt)) < 1e-5: + rewards.append(1.0) + else: + rewards.append(0.0) + except Exception: + # If evaluation fails, reward is 0 + rewards.append(0.0) + return np.array(rewards) +``` + + +## 4. 训练模版提示 (Training Template) + +**R1-Zero** 论文中的训练模版提示: + +```text +A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here + answer here . User: prompt. Assistant: +``` + +以 **Countdown Game** 为例子,代码如下: + +[源码位置](../src/dataset.py:22) + +```python +def generate_r1_prompt(numbers, target): + r1_prefix = [{ + "role": "system", + "content": "You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer." + }, + { + "role": "user", + "content": f"Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final equation and answer in tags, for example (1 + 2) / 3 = 1 ." + }, + { + "role": "assistant", + "content": "Let me solve this step by step.\n" + }] + return {"prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True), + "target": target} +``` + + +## 5. 创建一个 **Countdown Task** 数据集,并使用`mindspore.dataset`接口进行封装 + +[原始数据集格式](https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4) + +[源码位置](../src/dataset.py) + +```python +import numpy as np + +from transformers import PreTrainedTokenizer +from datasets import load_dataset +from mindone.transformers.mindspore_adapter.data import HF2MSDataset +import mindspore + + +def create_countdown_dataset( + hf_dataset_path: str = "Jiayi-Pan/Countdown-Tasks-3to4", + tokenizer: PreTrainedTokenizer = None, + batch_size: int = 1, + num_epochs: int = 1, + rank: int = 0, + rank_size: int = 1, +): + + # generate r1 prompt with a prefix for the model to already start with the thinking process + + # TODO: remove + def generate_r1_prompt(numbers, target): + r1_prefix = [{ + "role": "system", + "content": "You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer." + }, + { + "role": "user", + "content": f"Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final equation and answer in tags, for example (1 + 2) / 3 = 1 ." + }, + { + "role": "assistant", + "content": "Let me solve this step by step.\n" + }] + return {"prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True), + "target": target} + + # Load dataset from Hugging Face Hub + dataset = load_dataset(hf_dataset_path, split="train") + # select a random subset of 50k samples + dataset = dataset.shuffle(seed=42).select(range(50000)) + + # convert our dataset to the r1 prompt + dataset = dataset.map(lambda x: generate_r1_prompt(x["nums"], x["target"])) + + # split the dataset into train and test + train_test_split = dataset.train_test_split(test_size=0.1) + train_dataset = train_test_split["train"] + test_dataset = train_test_split["test"] + + # convert hf-datasets to mindspore-dataset + + def ms_data_collator(inputs, batch_info): + first = inputs[0] + assert isinstance(first, dict) + prompts = [x["prompt"] for x in inputs] + nums = [x["nums"] for x in inputs] + targets = np.array([int(x["target"]) for x in inputs]) + # prompt_inputs = tokenizer(prompts, return_tensors="np", padding=True, padding_side="left", add_special_tokens=False) + prompt_inputs = tokenizer( + prompts, + return_tensors="np", + padding="max_length", + truncation=True, + max_length=256, + padding_side="right", + add_special_tokens=False + ) + batch = { + "prompts": prompts, + "nums": nums, + "targets": targets, + "prompt_ids": prompt_inputs.input_ids, + "attention_mask": prompt_inputs.attention_mask, + } + return batch + + ms_train_dataset = mindspore.dataset.GeneratorDataset( + HF2MSDataset(train_dataset), column_names="item", shard_id=rank, num_shards=rank_size + ) + ms_train_dataset = ms_train_dataset.batch(batch_size=batch_size, per_batch_map=ms_data_collator) + ms_train_dataset = ms_train_dataset.repeat(1) + ms_train_dataset = ms_train_dataset.create_dict_iterator(num_epochs=num_epochs, output_numpy=True) + + ms_test_dataset = mindspore.dataset.GeneratorDataset( + HF2MSDataset(test_dataset), column_names="item", shard_id=rank, num_shards=rank_size + ) + ms_test_dataset = ms_test_dataset.batch(batch_size=1, per_batch_map=ms_data_collator) + ms_test_dataset = ms_test_dataset.repeat(1) + ms_test_dataset = ms_test_dataset.create_dict_iterator(num_epochs=1, output_numpy=True) + + return ms_train_dataset, ms_test_dataset +``` + + +## 6. 在一个简单任务中训练 - Countdown Task + +[Countdown Task](https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4) + +[源码位置](../train_r1_zero.py) + +### Step 1:创建数据集 + +```python +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") +tokenizer.pad_token = tokenizer.eos_token + +train_dataset, test_dataset = create_countdown_dataset("Jiayi-Pan/Countdown-Tasks-3to4", tokenizer=tokenizer) +``` + +### Step 2:创建模型 + +> 这里使用 **Qwen2.5-1.5B-Instruct** 模型 + +```python +import mindspore +from mindone.transformers import Qwen2ForCausalLM + +policy_model = Qwen2ForCausalLM.from_pretrained( + "Qwen/Qwen2.5-1.5B-Instruct", + mindspore_dtype=mindspore.bfloat16, + return_dict=False, +) +reference_model = Qwen2ForCausalLM.from_pretrained( + "Qwen/Qwen2.5-1.5B-Instruct", + mindspore_dtype=mindspore.bfloat16, + return_dict=False, +) + +grpo_model = GRPO(policy_model, reference_model, [format_reward, countdown_game_accuracy_reward], tokenizer, args) +``` + +### Step 3:创建优化器和MindSpore训练模型 + +```python +class TrainNet(nn.Cell): + def __init__(self, grpo_model: GRPO): + super(TrainNet, self).__init__(auto_prefix=False) + self.grpo_model = grpo_model + + def construct(self, *args, **kwargs): + loss = self.grpo_model.compute_loss(*args, **kwargs) + return loss + +optimizer = nn.AdamWeightDecay(policy_model.trainable_params(), learning_rate=5e-7) +train_model = TrainOneStepWrapper(TrainNet(grpo_model), optimizer) +``` + +### Step 4:启动训练 + +```python +train_model.set_train() +for step, batch in enumerate(train_dataset): + batch = batch["item"] + prompt_completion_ids, num_logits_to_keep, completion_mask, rewards = \ + grpo_model.get_completion_and_reward(batch) + + loss, _, overflow = train_model( + mindspore.Tensor(prompt_completion_ids, mindspore.int32), + num_logits_to_keep, + mindspore.Tensor(completion_mask, mindspore.bool_), + mindspore.Tensor(rewards, mindspore.float32), + ) + + print(f"step: {step}, loss: {loss}") +``` + + +## 6. 在你的环境运行 + +参考 [Mini-R1-Zero 示例代码 README](../README.md) From dd4d557bd3ae15f4f29af17d48868783c19dc16b Mon Sep 17 00:00:00 2001 From: zhanghuiyao <1814619459@qq.com> Date: Thu, 13 Feb 2025 20:58:34 +0800 Subject: [PATCH 2/2] add roadmap for reasoning tutorials --- model-examples/README.md | 2 ++ model-examples/reasoning-models/README.md | 14 ++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 model-examples/reasoning-models/README.md diff --git a/model-examples/README.md b/model-examples/README.md index 359b0de..21b8ee0 100644 --- a/model-examples/README.md +++ b/model-examples/README.md @@ -38,4 +38,6 @@ ### Reasoning/推理 +- [目录](./reasoning-models/README.md) + - [Mini DeepSeek R1 Zero 从零到一实践](./reasoning-models/mini-r1-zero/README.md) diff --git a/model-examples/reasoning-models/README.md b/model-examples/reasoning-models/README.md new file mode 100644 index 0000000..d3bc880 --- /dev/null +++ b/model-examples/reasoning-models/README.md @@ -0,0 +1,14 @@ +## Reasoning Tutorials + +### Updating + +- [mini r1 zero](./mini-r1-zero/README.md): A minimal reproduction of [DeepSeek R1 Zero](https://github.com/deepseek-ai/DeepSeek-R1) with native MindSpore. + +### TODO List + +- (TODO) open-r1: a reproduction of [huggingface/open-r1](https://github.com/huggingface/open-r1) +- (TODO) simpleRL-reason: a reproduction of [hkust-nlp/simpleRL-reason](https://github.com/hkust-nlp/simpleRL-reason) +- (TODO) search-and-learn: a reproduction of [huggingface/search-and-learn](https://github.com/huggingface/search-and-learn) +- (TODO) deepscaler: a reproduction of [agentica-project/deepscaler](https://github.com/agentica-project/deepscaler) +- (TODO) OpenThinker: a reproduction of [open-thoughts/open-thoughts](https://github.com/open-thoughts/open-thoughts) +- (TODO) s1: a reproduction of [simplescaling/s1](https://github.com/simplescaling/s1) by Li Fei-Fei teams