Skip to content

[fix] Fix KeyError getting project id for HuggingFace and Sklearn bac… #790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion label_studio_ml/examples/bert_classifier/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def fit(self, event, data, **additional_params):
if event not in ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'):
logger.info(f"Skip training: event {event} is not supported")
return
project_id = data['annotation']['project']
logger.debug(f"Project details payload for training: {data}")
project_id = data['project']['id']

# dowload annotated tasks from Label Studio
ls = label_studio_sdk.Client(self.LABEL_STUDIO_HOST, self.LABEL_STUDIO_API_KEY)
Expand Down
9 changes: 5 additions & 4 deletions label_studio_ml/examples/huggingface_ner/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_labels(self):
from_name, _, _ = li.get_first_tag_occurence('Labels', 'Text')
tag = li.get_tag(from_name)
return tag.labels

def setup(self):
"""Configure any paramaters of your model here
"""
Expand Down Expand Up @@ -102,7 +102,7 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
'score': avg_score / len(results),
'model_version': self.get('model_version')
})

return ModelResponse(predictions=predictions, model_version=self.get('model_version'))

def _get_tasks(self, project_id):
Expand Down Expand Up @@ -135,15 +135,16 @@ def tokenize_and_align_labels(self, examples, tokenizer):

tokenized_inputs["labels"] = labels
return tokenized_inputs

def fit(self, event, data, **kwargs):
"""Download dataset from Label Studio and prepare data for training in BERT
"""
if event not in ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING'):
logger.info(f"Skip training: event {event} is not supported")
return

project_id = data['annotation']['project']
logger.debug(f"Project details payload for training: {data}")
project_id = data['project']['id']
tasks = self._get_tasks(project_id)

if len(tasks) % self.START_TRAINING_EACH_N_UPDATES != 0 and event != 'START_TRAINING':
Expand Down
5 changes: 3 additions & 2 deletions label_studio_ml/examples/sklearn_text_classifier/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_label_studio_parameters(self) -> Dict:
'value': value,
'labels': labels
}

def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
"""
This method is used to predict the labels for a given list of tasks.
Expand Down Expand Up @@ -162,7 +162,8 @@ def fit(self, event, data, **kwargs):
logger.info(f"Skip training: event {event} is not supported")
return

project_id = data['annotation']['project']
logger.debug(f"Project details payload for training: {data}")
project_id = data['project']['id']
tasks = self._get_tasks(project_id)

# Get the labeling configuration parameters like labels and input / output annotation format names
Expand Down