Skip to content

Commit 2ae9e99

Browse files
authored
Merge pull request #42 from cov-lineages/dev
Handle multiple levels of definition within a constellation
2 parents (fad0291 + f7cba5c); commit 2ae9e99

File tree

3 files changed: +70 / -38 lines changed

scorpio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
_program = "scorpio"
2-
__version__ = "0.3.15"
2+
__version__ = "0.3.16"

scorpio/scripts/type_constellations.py

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,10 @@ def parse_json_in(refseq, features_dict, variants_file, constellation_names=None
295295
variant_list.append(record)
296296

297297
if "rules" in json_dict:
298-
rules = json_dict["rules"]
298+
if type(json_dict["rules"]) == dict and "default" in json_dict["rules"]:
299+
rules = json_dict["rules"]
300+
else:
301+
rules = {"default": json_dict["rules"]}
299302

300303
in_json.close()
301304
sorted_variants = sorted(variant_list, key=lambda x: int(x["ref_start"]))
@@ -346,9 +349,9 @@ def parse_csv_in(refseq, features_dict, variants_file, constellation_names=None,
346349
csv_in.close()
347350
rules = None
348351
if len(compulsory) > 0:
349-
rules = {}
352+
rules = {"default": {}}
350353
for var in compulsory:
351-
rules[var] = "alt"
354+
rules["default"][var] = "alt"
352355
sorted_variants = sorted(variant_list, key=lambda x: int(x["ref_start"]))
353356
return sorted_variants, name, rules
354357

@@ -570,24 +573,24 @@ def var_follows_rules(call, rule):
570573
else:
571574
return call == rule_call
572575

573-
def counts_follow_rules(counts, rules):
576+
def counts_follow_rules(counts, rules, key):
574577
# rules allowed include "max_ref", "min_alt", "min_snp_alt"
575578
is_rule_follower = True
576579
notes = []
577-
for rule in rules:
580+
for rule in rules[key]:
578581
if ":" in rule:
579582
continue
580583
elif str(rule).startswith("min") or str(rule).startswith("max"):
581584
rule_parts = rule.split("_")
582585
if len(rule_parts) <= 1:
583586
continue
584587
elif len(rule_parts) == 2:
585-
if rule_parts[0] == "min" and counts[rule_parts[1]] < rules[rule]:
588+
if rule_parts[0] == "min" and counts[rule_parts[1]] < rules[key][rule]:
586589
is_rule_follower = False
587-
elif rule_parts[0] == "max" and counts[rule_parts[1]] > rules[rule]:
590+
elif rule_parts[0] == "max" and counts[rule_parts[1]] > rules[key][rule]:
588591
is_rule_follower = False
589592
else:
590-
counts["rules"] += 1
593+
counts["rules"][key] += 1
591594
elif len(rule_parts) == 3:
592595
part = None
593596
if rule_parts[1] in ["substitution", "snp"]:
@@ -596,24 +599,27 @@ def counts_follow_rules(counts, rules):
596599
part = "indel"
597600
if not part:
598601
is_rule_follower = False
599-
elif rule_parts[0] == "min" and counts[part][rule_parts[2]] < rules[rule]:
602+
elif rule_parts[0] == "min" and counts[part][rule_parts[2]] < rules[key][rule]:
600603
is_rule_follower = False
601-
notes.append("%s_%s_count=%i is less than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[rule]))
602-
elif rule_parts[0] == "max" and counts[part][rule_parts[2]] > rules[rule]:
604+
notes.append("%s_%s_count=%i is less than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[key][rule]))
605+
elif rule_parts[0] == "max" and counts[part][rule_parts[2]] > rules[key][rule]:
603606
is_rule_follower = False
604-
notes.append("%s_%s_count=%i is more than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[rule]))
607+
notes.append("%s_%s_count=%i is more than %i" % (part, rule_parts[2], counts[part][rule_parts[2]], rules[key][rule]))
605608
else:
606-
counts["rules"] += 1
609+
counts["rules"][key] += 1
607610
else:
608-
logging.warning("Warning: Ignoring rule %s:%s" % (rule, str(rules[rule])))
611+
logging.warning("Warning: Ignoring rule %s:%s" % (rule, str(rules[key][rule])))
609612
return is_rule_follower, ";".join(notes)
610613

611614
def count_and_classify(record_seq, variant_list, rules):
612615
assert rules is not None
613-
counts = {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0, 'rules': 0,
616+
counts = {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0, 'rules': {},
614617
'substitution': {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0},
615618
'indel': {'ref': 0, 'alt': 0, 'ambig': 0, 'oth': 0}}
616-
is_rule_follower = True
619+
is_rule_follower_dict = {}
620+
for key in rules:
621+
is_rule_follower_dict[key] = True
622+
counts["rules"][key] = 0
617623

618624
for var in variant_list:
619625
call, query_allele = call_variant_from_fasta(record_seq, var)
@@ -623,20 +629,27 @@ def count_and_classify(record_seq, variant_list, rules):
623629
counts["substitution"][call] += 1
624630
elif var['type'] in ["ins", "del"]:
625631
counts["indel"][call] += 1
626-
if var["name"] in rules:
627-
if var_follows_rules(call, rules[var["name"]]):
628-
counts['rules'] += 1
629-
elif is_rule_follower:
630-
is_rule_follower = False
632+
for key in rules:
633+
if var["name"] in rules[key]:
634+
if var_follows_rules(call, rules[key][var["name"]]):
635+
counts['rules'][key] += 1
636+
elif is_rule_follower_dict[key]:
637+
is_rule_follower_dict[key] = False
631638

632639
counts['support'] = round(counts['alt']/float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']),4)
633640
counts['conflict'] = round(counts['ref'] /float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']),4)
634641

635-
if not is_rule_follower:
636-
return counts, False, ""
637-
else:
638-
call, note = counts_follow_rules(counts, rules)
639-
return counts, call, note
642+
for key in rules:
643+
if not is_rule_follower_dict[key]:
644+
continue
645+
else:
646+
call, note = counts_follow_rules(counts, rules, key)
647+
if call:
648+
counts["rules"] = counts["rules"][key]
649+
call = key
650+
return counts, call, note
651+
counts["rules"] = counts["rules"]["default"]
652+
return counts, False, ""
640653

641654

642655
def generate_barcode(record_seq, variant_list, ref_char=None, ins_char="?", oth_char="X",constellation_count_dict=None):
@@ -920,7 +933,15 @@ def combine_counts_call_notes(counts1, call1, note1, counts2, call2, note2):
920933
counts[key] = counts1[key] + counts2[key]
921934
counts['support'] = round(counts['alt'] / float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']), 4)
922935
counts['conflict'] = round(counts['ref'] / float(counts['alt'] + counts['ref'] + counts['ambig'] + counts['oth']), 4)
923-
call = call1 and call2
936+
if not call1 or not call2:
937+
call = False
938+
elif call1 == call2:
939+
call = call1
940+
elif call1 == "default":
941+
call = call2
942+
else:
943+
call = call1
944+
924945
note = note1
925946
if note != "" and note2 != "":
926947
note += ";" + note2
@@ -989,10 +1010,12 @@ def classify_constellations(in_fasta, list_constellation_files, constellation_na
9891010
best_support = 0
9901011
best_conflict = 1
9911012
best_counts = None
1013+
best_call = False
9921014
scores = {}
9931015
children = {}
9941016
for constellation in constellation_dict:
9951017
constellation_name = name_dict[constellation]
1018+
logging.debug("Consider constellation %s" %constellation_name)
9961019
parents = []
9971020
if not constellation_name:
9981021
continue
@@ -1015,20 +1038,25 @@ def classify_constellations(in_fasta, list_constellation_files, constellation_na
10151038
children[parent].append(constellation)
10161039

10171040
if call:
1041+
logging.debug("Have call for %s" %constellation_name)
10181042
if call_all:
1043+
if call != "default":
1044+
constellation_name = "%s %s" %(call, constellation_name)
10191045
lineages.append(constellation_name)
10201046
names.append(constellation)
10211047
elif constellation in children and best_constellation in children[constellation]:
1022-
continue
1048+
logging.debug("Ignore as parent of best constellation")
10231049
elif (not best_constellation) \
10241050
or (counts['support'] > best_support) \
10251051
or (counts['support'] == best_support and counts['conflict'] < best_conflict)\
10261052
or (counts['support'] == best_support and counts['conflict'] == best_conflict and counts['rules'] > best_counts["rules"])\
10271053
or (best_constellation in parents):
10281054
best_constellation = constellation
1055+
logging.debug("Set best constellation %s" %best_constellation)
10291056
best_support = counts['support']
10301057
best_conflict = counts['conflict']
10311058
best_counts = counts
1059+
best_call = call
10321060

10331061
if interspersion:
10341062
if counts["alt"] > 1:
@@ -1042,7 +1070,11 @@ def classify_constellations(in_fasta, list_constellation_files, constellation_na
10421070
counts['oth'], counts['rules'], counts['support'],
10431071
counts['conflict'], call, constellation, note))
10441072
if not call_all and best_constellation:
1045-
lineages.append(name_dict[best_constellation])
1073+
if best_call != "default":
1074+
best_constellation_name = "%s %s" % (best_call, name_dict[best_constellation])
1075+
else:
1076+
best_constellation_name = name_dict[best_constellation]
1077+
lineages.append(best_constellation_name)
10461078
names.append(best_constellation)
10471079

10481080
out_entries = [record.id, "|".join(lineages), "|".join([mrca_lineage_dict[n] for n in names])]

scorpio/tests/type_constellations_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ def test_parse_json_in():
147147
assert len([v for v in variant_list if v["type"] == "del"]) == 3
148148
assert len([v for v in variant_list if v["type"] == "aa"]) == 15
149149
assert name == "Lineage_X"
150-
assert rules["min_alt"] == 4
151-
assert rules["max_ref"] == 6
152-
assert rules["s:E484K"] == "alt"
150+
assert rules["default"]["min_alt"] == 4
151+
assert rules["default"]["max_ref"] == 6
152+
assert rules["default"]["s:E484K"] == "alt"
153153
assert mrca_lineage == "B.1.1.7"
154154
assert incompatible_lineages == "A|B.1.351"
155155

@@ -162,7 +162,7 @@ def test_parse_csv_in():
162162
assert len([v for v in variant_list if v["type"] == "del"]) == 3
163163
assert len([v for v in variant_list if v["type"] == "aa"]) == 15
164164
assert name == "lineage_X"
165-
assert rules["s:E484K"] == "alt"
165+
assert rules["default"]["s:E484K"] == "alt"
166166

167167

168168
def test_parse_textfile_in():
@@ -178,8 +178,8 @@ def test_parse_textfile_in():
178178
def test_parse_variants_in():
179179
in_files = ["%s/lineage_X.json" % data_dir, "%s/lineage_X.csv" % data_dir, "%s/lineage_X.txt" % data_dir]
180180
expect_names = ["Lineage_X", "lineage_X", "lineage_X"]
181-
rule_dict_json = {"min_alt": 4, "max_ref": 6, "s:E484K": "alt"}
182-
rule_dict_csv = {"s:E484K": "alt"}
181+
rule_dict_json = {"default": {"min_alt": 4, "max_ref": 6, "s:E484K": "alt"}}
182+
rule_dict_csv = {"default": {"s:E484K": "alt"}}
183183
rule_dict_txt = None
184184
expect_rules = [rule_dict_json, rule_dict_csv, rule_dict_txt]
185185

@@ -248,8 +248,8 @@ def test_count_and_classify():
248248
oth_string = "gaaattcgcccgta-gctcgcaatag"
249249
seqs = [Seq(ref_string), Seq(alt_string), Seq(alt_plus_string), Seq(oth_string)]
250250

251-
rules = {"min_alt": 1, "max_ref": 1, "snp2": "alt"}
252-
expect_classify = [False, False, True, False]
251+
rules = {"default": {"min_alt": 1, "max_ref": 1, "snp2": "alt"}}
252+
expect_classify = [False, False, "default", False]
253253
expect_counts = [{"ref": 5, "alt": 0, "ambig": 0, "oth": 1, "rules": 0, 'substitution': {'ref': 4, 'alt': 0, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 1, 'alt': 0, 'ambig': 0, 'oth': 1}, "support": 0.0, "conflict": 0.8333},
254254
{"ref": 1, "alt": 4, "ambig": 0, "oth": 1, "rules": 0, 'substitution': {'ref': 1, 'alt': 3, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 0, 'alt': 1, 'ambig': 0, 'oth': 1}, "support": 0.6667, "conflict": 0.1667},
255255
{"ref": 0, "alt": 5, "ambig": 0, "oth": 1, "rules": 3, 'substitution': {'ref': 0, 'alt': 4, 'ambig': 0, 'oth': 0}, 'indel': {'ref': 0, 'alt': 1, 'ambig': 0, 'oth': 1}, "support": 0.8333, "conflict": 0.0},

0 commit comments

Comments
 (0)