15 January 2017

Creating a custom GATK Walker (GATK 3.6) : my notebook


This is my notebook for creating a custom walker in GATK.


Description


I want to read a VCF file and get a table of category/count, something like this:


HAVE_ID TYPE COUNT
YES SNP 123
NO SNP 3
NO INDEL 13

Class Category


I create a class Category describing each row in the table. It's just a List of Strings:


static class Category
        implements Comparable<Category>
        {
        private final List<String> labels ;
        Category(final List<String> labels) {
            this.labels=new ArrayList<>(labels);
            }

As we're going to use Category as a key in an associative array / Map, we need to implement hashCode and equals:


Implementing hashCode and equals


static class Category
        implements Comparable<Category>
        {
        private final List<String> labels ;
        Category(final List<String> labels) {
            this.labels=new ArrayList<>(labels);
            }
        @Override
        public int hashCode() {
            return labels.hashCode();
            }
        @Override
        public int compareTo(final Category o) {
            for(int i=0;i< labels.size();++i)
                {
                int d = labels.get(i).compareTo(o.labels.get(i));
                if(d!=0) return d;
                }
            return 0;
            }
        @Override
        public boolean equals(Object o) {
            if(o==this) return true;
            if(o==null || !(o instanceof Category)) return false;
            return compareTo(Category.class.cast(o))==0;
            }
        }
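
To see why both methods matter, here is a minimal standalone sketch (no GATK needed). `LabelKey` and `LabelKeyDemo` are hypothetical names standing in for Category; unlike the class above, `equals` delegates directly to the list instead of going through compareTo, which is my simplification. Without the two overrides, every new key would be a different entry in the HashMap and every count would stay at 1.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for Category: a list of labels usable as a Map key. */
final class LabelKey {
    private final List<String> labels;

    LabelKey(final List<String> labels) {
        this.labels = new ArrayList<>(labels); // defensive copy
    }

    @Override
    public int hashCode() {
        return labels.hashCode();
    }

    @Override
    public boolean equals(final Object o) {
        if (o == this) return true;
        if (!(o instanceof LabelKey)) return false;
        return labels.equals(((LabelKey) o).labels);
    }
}

class LabelKeyDemo {
    /** Counts how many times each combination of labels occurs. */
    static Map<LabelKey, Long> count(final List<List<String>> rows) {
        final Map<LabelKey, Long> counts = new HashMap<>();
        for (final List<String> row : rows) {
            // A new key object per row: only equals/hashCode make them collide.
            counts.merge(new LabelKey(row), 1L, Long::sum);
        }
        return counts;
    }

    public static void main(final String[] args) {
        final Map<LabelKey, Long> counts = count(Arrays.asList(
                Arrays.asList("YES", "SNP"),
                Arrays.asList("NO", "INDEL"),
                Arrays.asList("YES", "SNP")));
        // The two ("YES","SNP") rows share a single entry.
        System.out.println(counts.size());
        System.out.println(counts.get(new LabelKey(Arrays.asList("YES", "SNP"))));
    }
}
```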

The main walker


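The main walker is called CountPredictions; it extends RodWalker. Why? I don't really know, I copied this from another walker. We're going to map/reduce over a 'Map' of counts, so both type parameters of RodWalker (the type returned by map and the type returned by reduce) are a Map, and my class is declared as: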


public class CountPredictions 
  extends RodWalker< Map<CountPredictions.Category,Long>, Map<CountPredictions.Category,Long>>
  {
  ...
  }

Documenting the Walker


A 'DocumentedGATKFeature' annotation is added so the GATK engine will be able to find our walker:


  (...)
  @DocumentedGATKFeature(
        summary="Count Predictions",
        groupName = HelpConstants.DOCS_CAT_VARMANIP,
        extraDocs = {CommandLineGATK.class} )
public class CountPredictions  extends RodWalker
  (....)

Describing the input and the output


The input is one VCF file:


@Input(fullName="variant", shortName = "V", doc="Input VCF file", required=true)
    public RodBinding<VariantContext> variants;

The output is a PrintStream (default is stdout) where we'll write the table:


@Output(doc="File to which result should be written")
    protected PrintStream out = System.out;

The other arguments


The other arguments are also decorated using Java annotations. These are the switches that add columns to the final table:


(...)
    @Argument(fullName="chrom",shortName="chrom",required=false,doc="Group by Chromosome/Contig")
    public boolean bychrom = false;
    @Argument(fullName="ID",shortName="ID",required=false,doc="Group by having/not-having ID")
    public boolean byID = false;
    @Argument(fullName="variantType",shortName="variantType",required=false,doc="Group by VariantType")
    (...)

The initialize() method


This method is called after the arguments have been parsed. As we want to be able to parse the SnpEff annotation, we're going to get the VCF header and extract the components of the ANN attribute to get the indexes for 'Annotation_Impact' and 'Transcript_BioType':


@Override
    public void initialize() {

        if(byImpact || bybiotype) {
            final VCFHeader vcfHeader  = GATKVCFUtils.getVCFHeadersFromRods(getToolkit()).get(variants.getName());
            final VCFInfoHeaderLine annInfo = vcfHeader.getInfoHeaderLine("ANN");
            if(annInfo==null)
                {
                logger.warn("NO ANN in "+variants.getSource());
                }
            else
                {
                int q0=annInfo.getDescription().indexOf('\'');
                int q1=annInfo.getDescription().lastIndexOf('\'');
                if(q0==-1 || q1<=q0)
                    {
                    logger.warn("Cannot parse "+annInfo.getDescription());
                    }
                else
                    {
                    final String fields[]=pipeRegex.split(annInfo.getDescription().substring(q0+1, q1));
                    for(int c=0;c<fields.length;++c)
                        {
                        final String column=fields[c].trim();
                        if(column.equals("Annotation_Impact"))
                            {
                            ann_impact_column=c;
                            }
                        else if(column.equals("Transcript_BioType")) 
                            {
                            ann_transcript_biotype_column=c;
                            }
                        }
                    }
                }
            }

        super.initialize();
    }
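
The header parsing can be tried outside GATK. Below is a minimal sketch of the same logic as a plain function; the class name `AnnHeaderSketch`, the method `columnIndex` and the shortened description string are my own assumptions for illustration (a real ANN header lists more fields).

```java
import java.util.regex.Pattern;

class AnnHeaderSketch {
    private static final Pattern PIPE = Pattern.compile("[|]");

    /**
     * Returns the 0-based index of 'name' among the pipe-separated fields
     * found between the first and the last single quote, or -1 if absent.
     */
    static int columnIndex(final String description, final String name) {
        final int q0 = description.indexOf('\'');
        final int q1 = description.lastIndexOf('\'');
        if (q0 == -1 || q1 <= q0) return -1; // no quoted field list
        final String[] fields = PIPE.split(description.substring(q0 + 1, q1));
        for (int c = 0; c < fields.length; ++c) {
            if (fields[c].trim().equals(name)) return c;
        }
        return -1;
    }

    public static void main(final String[] args) {
        // Illustrative, shortened description (assumption, not a real header).
        final String desc = "Functional annotations: 'Allele | Annotation | "
                + "Annotation_Impact | Gene_Name | Gene_ID | Feature_Type | "
                + "Feature_ID | Transcript_BioType'";
        System.out.println(columnIndex(desc, "Annotation_Impact"));  // 2
        System.out.println(columnIndex(desc, "Transcript_BioType")); // 7
    }
}
```

The same indexes are used later to pick the right token after splitting each ANN value of a variant on the pipe character.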

reduceInit()


As far as I understand, 'reduceInit()' creates the very first item of the map/reduce process, the initial value of the running total. So we're creating an empty associative map:


@Override
    public Map<CountPredictions.Category,Long> reduceInit() {
        return Collections.emptyMap();
    }

The reduce method


This method combines the result of one map() call with the running total, so we're creating a new map adding both counts:


@Override
    public Map<CountPredictions.Category,Long> reduce(Map<CountPredictions.Category,Long> value, Map<CountPredictions.Category,Long> sum) {
        final Map<CountPredictions.Category,Long> newmap = new HashMap<>(sum);
        for(Category cat:value.keySet()) {
            Long sv = sum.get(cat);
            Long vv = value.get(cat);
            newmap.put(cat, sv==null?vv:sv+vv);
        }
        return newmap;
    }
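
Here is how I picture reduceInit() and reduce() working together, as a standalone sketch. The names `ReduceSketch` and `traverse` are hypothetical, plain String keys replace Category to keep it short, and `Map.merge` replaces the explicit null check; the loop in `traverse` is my understanding of what the engine does, not code taken from GATK.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class ReduceSketch {
    /** Same role as reduceInit(): the starting value of the running total. */
    static Map<String, Long> reduceInit() {
        return Collections.emptyMap();
    }

    /** Same role as reduce(): returns a new map adding 'value' into 'sum'. */
    static Map<String, Long> reduce(final Map<String, Long> value,
                                    final Map<String, Long> sum) {
        // Copy first: 'sum' may be the immutable empty map from reduceInit().
        final Map<String, Long> merged = new HashMap<>(sum);
        for (final Map.Entry<String, Long> e : value.entrySet()) {
            merged.merge(e.getKey(), e.getValue(), Long::sum);
        }
        return merged;
    }

    /** Folds the per-site results of map() into one total (assumed engine behaviour). */
    static Map<String, Long> traverse(final List<Map<String, Long>> perSite) {
        Map<String, Long> sum = reduceInit();
        for (final Map<String, Long> value : perSite) {
            sum = reduce(value, sum);
        }
        return sum;
    }

    public static void main(final String[] args) {
        final Map<String, Long> snp = Collections.singletonMap("SNP", 1L);
        final Map<String, Long> indel = Collections.singletonMap("INDEL", 1L);
        System.out.println(traverse(Arrays.asList(snp, indel, snp)));
    }
}
```

Copying into a new HashMap on every call is what makes it safe to return Collections.emptyMap() from reduceInit(): the immutable map is never modified.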

The map method


This is the workhorse of the walker. As far as I can see it is called for each variant. The method 'tracker.getValues()' returns a list of all the variants at the current position (here, they all come from the same file because we only have one input). At the end of the method we must return a count of Category for those variants:


@Override
    public Map<CountPredictions.Category,Long> map(final RefMetaDataTracker tracker,final ReferenceContext ref, final AlignmentContext context) {
        if ( tracker == null )return Collections.emptyMap();
        final Map<CountPredictions.Category,Long> count = new HashMap<>();
        for(final VariantContext ctx: tracker.getValues(this.variants,context.getLocation()))
            {
            ...
            }
        return count;
        }

Filling the Category


For each variant, a new Category is filled and its count is incremented:


(...)
    final List<String> labels=new ArrayList<>();
    if(bychrom) labels.add(ctx.getContig());
    if(byID) labels.add(ctx.hasID()?"Y":".");
    (...)
    final Category cat=new Category(labels);
    Long n=count.get(cat);
    count.put(cat, n==null?1L:n+1);
    (...)

At the end, onTraversalDone, printing the table


When onTraversalDone is called, the table is printed using the GATKReport class. The columns must be added in the same order as the labels were added in map():


@Override
    public void onTraversalDone(final Map<CountPredictions.Category,Long> counts) {
        GATKReportTable table=new GATKReportTable(
                "Variants", "Variants "+variants.getSource(),0);
        if(bychrom) table.addColumn("CONTIG");
        if(byID) table.addColumn("IN_DBSNP");
        if(byType) table.addColumn("TYPE");
        (...)
        table.addColumn("COUNT");

        int nRows=0;
        for(final Category cat: counts.keySet())
            {
            for(int x=0;x<cat.labels.size();++x)
                {
                table.set(nRows, x, cat.labels.get(x));
                }

            table.set(nRows, cat.labels.size(), counts.get(cat));
            ++nRows;
            }
        GATKReport report = new GATKReport();
        report.addTable(table);
        report.print(this.out);
        out.flush();
        }
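
Because the counts live in a HashMap, the row order of the report is arbitrary. Category implements Comparable, so the rows could be sorted before printing. The sketch below shows the idea outside GATK: it is not the GATKReport format, only plain tab-separated lines, and the names `TableSketch` and `render` are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class TableSketch {
    /** Compares two label lists label by label, like Category.compareTo. */
    static final Comparator<List<String>> BY_LABELS = (a, b) -> {
        for (int i = 0; i < Math.min(a.size(), b.size()); ++i) {
            final int d = a.get(i).compareTo(b.get(i));
            if (d != 0) return d;
        }
        return Integer.compare(a.size(), b.size());
    };

    /** Returns a header line followed by one tab-separated line per key, sorted. */
    static List<String> render(final List<String> columns,
                               final Map<List<String>, Long> counts) {
        final Map<List<String>, Long> sorted = new TreeMap<>(BY_LABELS);
        sorted.putAll(counts);

        final List<String> lines = new ArrayList<>();
        lines.add(String.join("\t", columns) + "\tCOUNT");
        for (final Map.Entry<List<String>, Long> e : sorted.entrySet()) {
            lines.add(String.join("\t", e.getKey()) + "\t" + e.getValue());
        }
        return lines;
    }

    public static void main(final String[] args) {
        final Map<List<String>, Long> counts = new HashMap<>();
        counts.put(Arrays.asList("YES", "SNP"), 123L);
        counts.put(Arrays.asList("NO", "SNP"), 3L);
        counts.put(Arrays.asList("NO", "INDEL"), 13L);
        for (final String line : render(Arrays.asList("HAVE_ID", "TYPE"), counts)) {
            System.out.println(line);
        }
    }
}
```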

All in one


The full source code:


package mygatk;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import org.broadinstitute.gatk.engine.CommandLineGATK;
import org.broadinstitute.gatk.engine.GATKVCFUtils;
import org.broadinstitute.gatk.engine.walkers.RodWalker;
import org.broadinstitute.gatk.utils.commandline.Argument;
import org.broadinstitute.gatk.utils.commandline.Input;
import org.broadinstitute.gatk.utils.commandline.Output;
import org.broadinstitute.gatk.utils.commandline.RodBinding;
import org.broadinstitute.gatk.utils.contexts.AlignmentContext;
import org.broadinstitute.gatk.utils.contexts.ReferenceContext;
import org.broadinstitute.gatk.utils.help.DocumentedGATKFeature;
import org.broadinstitute.gatk.utils.help.HelpConstants;
import org.broadinstitute.gatk.utils.refdata.RefMetaDataTracker;
import org.broadinstitute.gatk.utils.report.GATKReport;
import org.broadinstitute.gatk.utils.report.GATKReportTable;

import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFInfoHeaderLine;

/**
 * Test Documentation
 */
@DocumentedGATKFeature(
        summary="Count Predictions",
        groupName = HelpConstants.DOCS_CAT_VARMANIP,
        extraDocs = {CommandLineGATK.class} )
public class CountPredictions  extends RodWalker<Map<CountPredictions.Category,Long>, Map<CountPredictions.Category,Long>> {
    private enum IMPACT { HIGH, MODERATE, MODIFIER,LOW};
    @Input(fullName="variant", shortName = "V", doc="Input VCF file", required=true)
    public RodBinding<VariantContext> variants;
    @Output(doc="File to which result should be written")
    protected PrintStream out = System.out;

    @Argument(fullName="minQuality",shortName="mq",required=false,doc="Group by Quality. Set the threshold for Minimum Quality")
    public double minQuality = -1;
    @Argument(fullName="chrom",shortName="chrom",required=false,doc="Group by Chromosome/Contig")
    public boolean bychrom = false;
    @Argument(fullName="ID",shortName="ID",required=false,doc="Group by ID")
    public boolean byID = false;
    @Argument(fullName="variantType",shortName="variantType",required=false,doc="Group by VariantType")
    public boolean byType = false;
    @Argument(fullName="filter",shortName="filter",required=false,doc="Group by FILTER")
    public boolean byFilter = false;
    @Argument(fullName="impact",shortName="impact",required=false,doc="Group by ANN/IMPACT")
    public boolean byImpact = false;
    @Argument(fullName="biotype",shortName="biotype",required=false,doc="Group by ANN/biotype")
    public boolean bybiotype = false;
    @Argument(fullName="nalts",shortName="nalts",required=false,doc="Group by number of ALTS")
    public boolean bynalts = false;
    @Argument(fullName="affected",shortName="affected",required=false,doc="Group by number of Samples called and not HOMREF")
    public boolean byAffected = false;
    @Argument(fullName="called",shortName="called",required=false,doc="Group by number of Samples called")
    public boolean byCalled = false;
    @Argument(fullName="maxSamples",shortName="maxSamples",required=false,doc="if the number of samples called/affected is greater than or equal to --maxSamples, use the label \"GE_\" followed by the value of --maxSamples")
    public int maxSamples = Integer.MAX_VALUE;

    private final Pattern pipeRegex=Pattern.compile("[\\|]");

    static class Category
        implements Comparable<Category>
        {
        private final List<String> labels ;
        Category(final List<String> labels) {
            this.labels=new ArrayList<>(labels);
            }
        @Override
        public int hashCode() {
            return labels.hashCode();
            }
        @Override
        public int compareTo(final Category o) {
            for(int i=0;i< labels.size();++i)
                {
                int d = labels.get(i).compareTo(o.labels.get(i));
                if(d!=0) return d;
                }
            return 0;
            }
        @Override
        public boolean equals(Object o) {
            if(o==this) return true;
            if(o==null || !(o instanceof Category)) return false;
            return compareTo(Category.class.cast(o))==0;
            }
        }
    private int ann_impact_column=-1;/* eg: MODIFIER / LOW */
    private int ann_transcript_biotype_column=-1;/* eg: transcript /intergenic_region */

    @Override
    public void initialize() {

        if(byImpact || bybiotype) {
            final VCFHeader vcfHeader  = GATKVCFUtils.getVCFHeadersFromRods(getToolkit()).get(variants.getName());
            final VCFInfoHeaderLine annInfo = vcfHeader.getInfoHeaderLine("ANN");
            if(annInfo==null)
                {
                logger.warn("NO ANN in "+variants.getSource());
                }
            else
                {
                int q0=annInfo.getDescription().indexOf('\'');
                int q1=annInfo.getDescription().lastIndexOf('\'');
                if(q0==-1 || q1<=q0)
                    {
                    logger.warn("Cannot parse "+annInfo.getDescription());
                    }
                else
                    {
                    final String fields[]=pipeRegex.split(annInfo.getDescription().substring(q0+1, q1));
                    for(int c=0;c<fields.length;++c)
                        {
                        final String column=fields[c].trim();
                        if(column.equals("Annotation_Impact"))
                            {
                            ann_impact_column=c;
                            }
                        else if(column.equals("Transcript_BioType")) 
                            {
                            ann_transcript_biotype_column=c;
                            }
                        }
                    }
                }
            }

        super.initialize();
    }

    @Override
    public Map<CountPredictions.Category,Long> map(final RefMetaDataTracker tracker,final ReferenceContext ref, final AlignmentContext context) {
        if ( tracker == null )return Collections.emptyMap();
        final Map<CountPredictions.Category,Long> count = new HashMap<>();
        for(final VariantContext ctx: tracker.getValues(this.variants,context.getLocation()))
            {
            final List<String> labels=new ArrayList<>();
            if(bychrom) labels.add(ctx.getContig());
            if(byID) labels.add(ctx.hasID()?"Y":".");
            if(byType) labels.add(ctx.getType().name());
            if(byFilter) labels.add(ctx.isFiltered()?"F":".");
            if(minQuality>=0) {
                labels.add(ctx.hasLog10PError() && ctx.getPhredScaledQual()>=this.minQuality ?
                        ".":"LOWQUAL"
                        );
                }

            if(byImpact || bybiotype)
                {
                String biotype=null;
                IMPACT impact=null;
                final List<Object> anns = ctx.getAttributeAsList("ANN");
                for(final Object anno:anns) {
                    final String tokens[]=this.pipeRegex.split(anno.toString());
                    if(this.ann_impact_column==-1 ||
                            this.ann_impact_column >= tokens.length ||
                            tokens[this.ann_impact_column].isEmpty()
                            ) continue;
                    IMPACT currImpact = IMPACT.valueOf(tokens[this.ann_impact_column]);
                    if(impact!=null && currImpact.compareTo(impact)<0) continue;
                    impact=currImpact;
                    biotype=null;
                    if(this.ann_transcript_biotype_column==-1 ||
                            this.ann_transcript_biotype_column >= tokens.length ||
                            tokens[this.ann_transcript_biotype_column].isEmpty()
                            ) continue;
                    biotype=tokens[ann_transcript_biotype_column];
                    }
                if(byImpact) labels.add(impact==null?".":impact.name());
                if(bybiotype) labels.add(biotype==null?".":biotype);
                }
            if(bynalts)labels.add(String.valueOf(ctx.getAlternateAlleles().size()));
            if(byAffected || byCalled) 
                {
                int nc=0;
                int ng=0;
                for(int i=0;i< ctx.getNSamples();++i)
                    {
                    final Genotype g= ctx.getGenotype(i);
                    if(!(g.isNoCall() || g.isHomRef()))
                        {
                        ng++;
                        }
                    if(g.isCalled())
                        {
                        nc++;
                        }
                    }
                if(byCalled) labels.add(nc< maxSamples?String.valueOf(nc):"GE_"+maxSamples);
                if(byAffected) labels.add(ng< maxSamples?String.valueOf(ng):"GE_"+maxSamples);
                }

            final Category cat=new Category(labels);
            Long n=count.get(cat);
            count.put(cat, n==null?1L:n+1);

            }
        return count;
    }

    @Override
    public Map<CountPredictions.Category,Long> reduce(Map<CountPredictions.Category,Long> value, Map<CountPredictions.Category,Long> sum) {
        final Map<CountPredictions.Category,Long> newmap = new HashMap<>(sum);
        for(Category cat:value.keySet()) {
            Long sv = sum.get(cat);
            Long vv = value.get(cat);
            newmap.put(cat, sv==null?vv:sv+vv);
        }
        return newmap;
    }

    @Override
    public Map<CountPredictions.Category,Long> reduceInit() {
        return Collections.emptyMap();
    }


    @Override
    public void onTraversalDone(final Map<CountPredictions.Category,Long> counts) {
        GATKReportTable table=new GATKReportTable(
                "Variants", "Variants "+variants.getSource(),0);
        if(bychrom) table.addColumn("CONTIG");
        if(byID) table.addColumn("IN_DBSNP");
        if(byType) table.addColumn("TYPE");
        if(byFilter) table.addColumn("FILTER");
        if(minQuality>=0) table.addColumn("QUAL_GE_"+this.minQuality);
        if(byImpact) table.addColumn("IMPACT");
        if(bybiotype) table.addColumn("BIOTYPE");
        if(bynalts) table.addColumn("N_ALT_ALLELES");
        if(byCalled) table.addColumn("CALLED_SAMPLES");
        if(byAffected) table.addColumn("AFFECTED_SAMPLES");
        table.addColumn("COUNT");

        int nRows=0;
        for(final Category cat: counts.keySet())
            {
            for(int x=0;x<cat.labels.size();++x)
                {
                table.set(nRows, x, cat.labels.get(x));
                }

            table.set(nRows, cat.labels.size(), counts.get(cat));
            ++nRows;
            }
        GATKReport report = new GATKReport();
        report.addTable(table);
        report.print(this.out);
        out.flush();

        logger.info("TraversalDone");

        }

    }

Compiling and testing


gatk.jar=/path/to/GenomeAnalysisTK.jar
VCF=input.vcf.gz
test: dist/mygatk.jar
        java -cp ${gatk.jar}:dist/mygatk.jar \
                org.broadinstitute.gatk.engine.CommandLineGATK \
                -R /path/to/ref.fasta \
                -T CountPredictions -V ${VCF} -ID

dist/mygatk.jar : ${gatk.jar} ./mygatk/CountPredictions.java
        mkdir -p tmp dist
        javac -d tmp -cp .:${gatk.jar}  ./mygatk/CountPredictions.java
        jar cfv $@ -C tmp .
        rm -rf tmp

Output:


INFO  16:47:16,545 HelpFormatter - ---------------------------------------------------------------------------------- 
INFO  16:47:16,547 HelpFormatter - The Genome Analysis Toolkit (GATK) v3.6-0-g89b7209, Compiled 2016/06/01 22:27:29 
INFO  16:47:16,547 HelpFormatter - Copyright (c) 2010-2016 The Broad Institute 
INFO  16:47:16,547 HelpFormatter - For support and documentation go to https://www.broadinstitute.org/gatk 
INFO  16:47:16,714 HelpFormatter - [Fri Jan 13 16:47:16 CET 2017] Executing on Linux 3.10.0-327.36.3.el7.x86_64 amd64 
INFO  16:47:16,714 HelpFormatter - Java HotSpot(TM) 64-Bit Server VM 1.8.0_102-b14 JdkDeflater 
INFO  16:47:16,717 HelpFormatter - Program Args: -R ref.fasta -T CountPredictions -V input.vcf.gz -ID 
INFO  16:47:16,724 HelpFormatter - Date/Time: 2017/01/13 16:47:16 
INFO  16:47:16,725 HelpFormatter - ---------------------------------------------------------------------------------- 
INFO  16:47:16,725 HelpFormatter - ---------------------------------------------------------------------------------- 
INFO  16:47:16,740 GenomeAnalysisEngine - Strictness is SILENT 
INFO  16:47:16,829 GenomeAnalysisEngine - Downsampling Settings: Method: BY_SAMPLE, Target Coverage: 1000 
WARN  16:47:16,881 IndexDictionaryUtils - Track variant doesn't have a sequence dictionary built in, skipping dictionary validation 
INFO  16:47:16,936 GenomeAnalysisEngine - Preparing for traversal 
INFO  16:47:16,940 GenomeAnalysisEngine - Done preparing for traversal 
INFO  16:47:16,940 ProgressMeter - [INITIALIZATION COMPLETE; STARTING PROCESSING] 
INFO  16:47:16,941 ProgressMeter -                 | processed |    time |    per 1M |           |   total | remaining 
INFO  16:47:16,941 ProgressMeter -        Location |     sites | elapsed |     sites | completed | runtime |   runtime 
INFO  16:47:46,946 ProgressMeter -  chr22:26185095    167649.0    30.0 s       3.0 m       91.0%    32.0 s       2.0 s 
#:GATKReport.v1.1:1
#:GATKTable:2:1:%s:%s:;
#:GATKTable:Variants:Variants input.vcf.gz
IN_DBSNP  COUNT 
.         139984

INFO  16:47:52,501 CountPredictions - TraversalDone 
INFO  16:47:52,502 ProgressMeter -            done    202931.0    35.0 s       2.9 m       91.1%    38.0 s       3.0 s 
INFO  16:47:52,502 ProgressMeter - Total runtime 35.56 secs, 0.59 min, 0.01 hours 
------------------------------------------------------------------------------------------
Done. There were 1 WARN messages, the first 1 are repeated below.
WARN  16:47:16,881 IndexDictionaryUtils - Track variant doesn't have a sequence dictionary built in, skipping dictionary validation 
------------------------------------------------------------------------------------------