{- Husky A pure functional WFST decoder for automatic speech recognition Developed by Takahiro Shinozaki Supported by JST Research Seeds Quest Program 2010.1-2010.12 -} {- Copyright (c) 2010 Takahiro Shinozaki All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Takahiro Shinozaki nor the names of other contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR THE CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-}
{-# LANGUAGE BangPatterns #-}

-- Version 1.0 2010.12.28 Initial version

-- To compile with ghc: ghc -O2 --make husky.hs

import System.Environment (getArgs)
import System.Exit (exitFailure)
import System.IO (hPutStrLn, stderr)
import Text.Printf
import Data.Binary
import Data.Binary.Get
import Data.Binary.IEEE754
import Control.Monad
import qualified Data.IntMap as Map
import qualified Data.ByteString.Lazy.Char8 as L8
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.MemoCombinators as Memo
import qualified Data.List as List

-- | Entry point: @husky config wfstf spdff scpf@.
--
-- Loads the decoder configuration, the WFST and the state pdfs, then decodes
-- every feature file listed in the scp file. The printed model summaries
-- (final state, pdf checksum) also force the parsed structures to WHNF
-- before decoding starts.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [configf, wfstf, spdff, scpf] -> do
      config <- readConfig configf
      wfst <- wfstFromFile wfstf config
      seq wfst (print (finalSt wfst))
      spdfs <- spdfsFromFile spdff
      seq spdfs (print (inspectSpdfs spdfs))
      scp <- readScpFile scpf
      mapM_ (decodeWfstMain config wfst spdfs) scp
    _ -> do
      -- Usage errors go to stderr and give a non-zero exit status.
      hPutStrLn stderr "error: husky config wfstf spdff scpf"
      exitFailure

-- | Decode one HTK feature file and print, in order: the file name, the
-- number of frames, the score of the best path, and the non-epsilon
-- (positive) output labels along that path.
decodeWfstMain :: DcdConfig -> WFST -> V.Vector GMixture -> String -> IO ()
decodeWfstMain config wfst spdfs feaf = do
  print feaf
  feas <- htkParmFromFile feaf
  seq feas (print (List.length (feaVec feas)))
  let lat = decodeWfst config wfst spdfs (feaVec feas)
      (bestArcs, bestScore) = backTrack lat (finalState wfst)
  print bestScore
  print $ filter (> 0) $ map arcOut bestArcs

-- Config params ----------------------------------------

-- | Decoder parameters, read from the config file by 'readConfig'.
data DcdConfig = DcdConfig
  { insPenalty :: Double -- ^ insertion penalty (not referenced by the decoding code visible in this file)
  , lmWeight :: Double   -- ^ LM weight (not referenced by the decoding code visible in this file)
  , band :: Int          -- ^ maximum number of hypotheses kept per frame
  , beam :: Double       -- ^ hypotheses whose score is 'beam' or more above the best are pruned
  } deriving (Show, Eq)

-- | Read a config file of whitespace-separated @key=value@ tokens
-- (keys: insPenalty, lmWeight, band, beam). Unknown tokens are ignored.
-- Missing keys keep their defaults: insPenalty=0, lmWeight=10, band=10000,
-- beam=200.
readConfig :: String -> IO DcdConfig
readConfig fname = do
  buf <- L8.readFile fname
  return (List.foldl' configPsr defaultConfig (List.map L8.unpack (L8.words buf)))
  where
    defaultConfig = DcdConfig 0 10 10000 200

-- | Apply one @key=value@ token to a configuration. A token that matches
-- none of the known keys leaves the configuration unchanged. Every key is
-- checked for every token, which is fine because config files are small.
configPsr :: DcdConfig -> String -> DcdConfig
configPsr (DcdConfig insp lmw bnd bm) definition =
  DcdConfig (getVal "insPenalty=" insp)
            (getVal "lmWeight=" lmw)
            (getVal "band=" bnd)
            (getVal "beam=" bm)
  where
    -- The value after defname, or dfv when the token is for another key.
    -- A malformed value is reported together with the offending token,
    -- instead of the bare "Prelude.read: no parse".
    getVal :: Read a => String -> a -> a
    getVal defname dfv = case List.stripPrefix defname definition of
      Just val -> case reads val of
        [(v, "")] -> v
        _ -> error ("configPsr: cannot parse config value: " ++ definition)
      Nothing -> dfv

-- Feature vectors --------------------------------------

-- | One frame of features.
type FeaVector = U.Vector Double

-- | Contents of an HTK parameter file: the header fields plus all frames.
data FeaInfo = FeaInfo
  { nSamples :: Word32   -- ^ number of frames
  , sampPeriod :: Word32 -- ^ frame period (HTK units, presumably 100ns — confirm)
  , sampSize :: Word16   -- ^ bytes per frame (4 bytes per component)
  , parmKind :: Word16   -- ^ HTK parameter kind code
  , feaVec :: [FeaVector]
  }

instance Show FeaInfo where
  show (FeaInfo ns sp ss pk fv) =
    "FeaInfo nSamples=" ++ show ns
      ++ " sampPeriod=" ++ show sp
      ++ " sampSize=" ++ show ss
      ++ " parmKind=" ++ show pk
      -- One line per frame; concatMap avoids the quadratic left-nested (++).
      ++ concatMap (\v -> '\n' : show v) fv

-- | Read an HTK parameter file (big-endian header followed by float frames).
htkParmFromFile :: String -> IO FeaInfo
htkParmFromFile fname = do
  buf <- L8.readFile fname
  return (runGet readHtkParm buf)

-- | Parser for an HTK parameter file. Each frame has @sampSize / 4@
-- big-endian 32-bit float components, widened to Double.
readHtkParm :: Get FeaInfo
readHtkParm = do
  ns <- getWord32be
  sp <- getWord32be
  ss <- getWord16be
  pk <- getWord16be
  feas <- replicateM (fromIntegral ns) (readVec (fromIntegral (div ss 4)))
  return (FeaInfo ns sp ss pk feas)
  where
    readVec :: Int -> Get (U.Vector Double)
    readVec dim = do
      comps <- replicateM dim getFloat32be
      return (U.fromList (List.map realToFrac comps))

-- Read feature list file: whitespace-separated feature file names.
readScpFile :: String -> IO [String]
readScpFile fname = do
  buf <- L8.readFile fname
  return (List.map L8.unpack (L8.words buf))

-- HMM state observation density ------------------------

-- | A Gaussian mixture state pdf. All per-component fields use the same
-- component order.
data GMixture = GMixture
  { name :: [Char]
  , nmix :: Int                -- ^ number of components declared in the file
  , weights :: U.Vector Double -- ^ component weights
  , means :: [U.Vector Double] -- ^ one mean vector per component
  , vars :: [U.Vector Double]  -- ^ diagonal variances, one vector per component
  , gconsts :: U.Vector Double -- ^ per-component constant term (presumably HTK GCONST — confirm)
  }

-- | Read all state pdfs from a text file (see 'readSpdfs').
spdfsFromFile :: String -> IO (V.Vector GMixture)
spdfsFromFile fname = do
  buf <- L8.readFile fname
  return (readSpdfs buf)

-- | Parse the state pdfs from a whitespace-tokenised text file. Each pdf is
-- "~s" NAME MARKER NMIX followed by NMIX components. Each component is
-- MARKER MIXID WEIGHT, MARKER DIM then DIM mean values, MARKER DIM then
-- DIM variances, and MARKER GCONST.
--
-- NOTE(review): the MARKER positions are matched against the literal ""
-- below, but L8.words never yields empty tokens. As written those patterns
-- cannot match, so any non-empty input fails (now with a readable error).
-- They look like marker tokens (e.g. HTK-style NUMMIXES / MIXTURE / MEAN /
-- VARIANCE / GCONST tags) lost in transcription. Restore them from the
-- original source or model-format spec.
readSpdfs :: L8.ByteString -> V.Vector GMixture
readSpdfs buf = V.fromList $ gmReadAll $ List.map L8.unpack $ L8.words buf
  where
    -- Parse pdfs until the token stream is exhausted.
    gmReadAll [] = []
    gmReadAll toks = gmm : gmReadAll rst
      where
        (!gmm, rst) = gmRead toks

    -- One pdf header plus its components.
    gmRead :: [String] -> (GMixture, [String])
    gmRead ("~s" : nm : "" : nmixTok : rst) =
      ( GMixture nm nmix
          (U.fromList (List.map mixWeight comps))
          (List.map mixMean comps)
          (List.map mixVar comps)
          (U.fromList (List.map mixGconst comps))
      , gmRst )
      where
        nmix = toInt nmixTok
        -- Components are accumulated in reverse file order. All four fields
        -- share that order, so the mixture itself is unaffected.
        (!comps, gmRst) = gcNRead nmix [] rst
        mixWeight (!w, !_, !_, !_) = w
        mixMean (!_, !m, !_, !_) = m
        mixVar (!_, !_, !v, !_) = v
        mixGconst (!_, !_, !_, !c) = c
    gmRead toks = parseError "a \"~s\" pdf header" toks

    -- Read iter components, pushing each onto out. A non-positive count
    -- reads nothing (the original looped on negative counts).
    gcNRead !iter !out toks
      | iter <= 0 = (out, toks)
      | otherwise = gcNRead (iter - 1) ((mweight, mean, var, gconst) : out) gcRst
      where
        (!mweight, !mean, !var, !gconst, gcRst) = gcRead toks

    -- One component: weight, mean vector, variance vector, gconst.
    gcRead :: [String] -> (Double, U.Vector Double, U.Vector Double, Double, [String])
    gcRead ("" : _mixId : mweight : rst) = (toDouble mweight, mean, var, gconst, gconstRst)
      where
        (!mean, meanRst) = tagVecRead rst
        (!var, varRst) = tagVecRead meanRst
        (!gconst, gconstRst) = gconstRead varRst
    gcRead toks = parseError "a mixture component" toks

    -- MARKER DIM followed by DIM numbers (shared by means and variances).
    tagVecRead ("" : dim : rst) = vecRead (toInt dim) rst
    tagVecRead toks = parseError "a mean/variance vector" toks

    -- MARKER GCONST
    gconstRead ("" : g : rst) = (toDouble g, rst)
    gconstRead toks = parseError "a gconst" toks

    vecRead !dim toks = (U.fromList (List.map toDouble (List.take dim toks)), List.drop dim toks)
    toInt !wd = read wd :: Int
    toDouble !wd = read wd :: Double

    parseError :: String -> [String] -> a
    parseError what toks =
      error ("readSpdfs: expected " ++ what ++ " near: " ++ unwords (List.take 5 toks))

-- | Sum of every numeric parameter of all pdfs. 'main' prints it as a crude
-- checksum, which also forces the parsed model.
inspectSpdfs :: V.Vector GMixture -> Double
inspectSpdfs spdfs = V.sum (V.map inspectGMM spdfs)
  where
    inspectGMM spdf =
      U.sum (weights spdf)
        + List.sum (List.map U.sum (means spdf))
        + List.sum (List.map U.sum (vars spdf))
        + U.sum (gconsts spdf)

-- WFST definitions --------------------------------------
-- Assumptions are 1) the initial state is the first state in the file,
-- 2) there is only one final state, 3) the final state does not have an
-- associated cost, 4) arcs are sorted by their start and end node IDs.
-- Arc: (start state, end state, input label, output label, weight).
-- Input label 0 is epsilon; a positive input label is an HMM state.
type Arc = (Int, Int, Int, Int, Double)

-- | A partial path (arcs in traversal order) and its accumulated score.
-- Scores are costs: smaller is better.
type Path = ([Arc], Double)

-- | A WFST stored as a flat arc array plus a per-state index into it.
data WFST = WFST
  { -- startArray[s] = (index of the first arc of state s in arcArray,
    --                  number of arcs having start node s)
    startArray :: U.Vector (Int, Int)
  , arcArray :: U.Vector Arc
  , initialSt :: Int
  , finalSt :: Int
  } -- deriving (Show)

-- Interface functions for Arc
arcStart :: Arc -> Int
arcStart (p, _, _, _, _) = p

arcEnd :: Arc -> Int
arcEnd (_, n, _, _, _) = n

arcIn :: Arc -> Int
arcIn (_, _, i, _, _) = i

arcOut :: Arc -> Int
arcOut (_, _, _, o, _) = o

arcWt :: Arc -> Double
arcWt (_, _, _, _, w) = w

-- | End state of a (non-empty) path.
pathEnd :: [Arc] -> Int
pathEnd a = arcEnd (last a)

-- | Start state of a (non-empty) path.
pathStart :: [Arc] -> Int
pathStart a = arcStart (head a)

initState :: WFST -> Int
initState (WFST _ _ s _) = s

finalState :: WFST -> Int
finalState (WFST _ _ _ s) = s

-- | (first arc index, arc count) for the arcs leaving a state. The index
-- only covers states up to the largest arc start state. A state outside it
-- (e.g. a final state with a higher id than every start state) has no
-- outgoing arcs and gets (0, 0) instead of an out-of-bounds crash.
startIdx :: WFST -> Int -> (Int, Int)
startIdx wfst idx
  | idx >= 0 && idx < U.length table = table U.! idx
  | otherwise = (0, 0)
  where
    table = startArray wfst

-- | Slice of the arc array holding the arcs that leave @state@.
outArcs :: WFST -> Int -> U.Vector Arc
outArcs wfst state = U.unsafeSlice start num (arcArray wfst)
  where
    (start, num) = startIdx wfst state

-- | Epsilon (input label 0) arcs leaving @state@, in arc-array order. All of
-- the state's arcs are scanned: arcs are only assumed to be sorted by start
-- and end node, so epsilon arcs need not come first.
findNextEpsArcs :: WFST -> Int -> [Arc]
findNextEpsArcs wfst state = U.toList (U.filter (\x -> arcIn x == 0) (outArcs wfst state))

-- | HMM arcs (input label > 0) leaving @state@. The observation cost
-- @hmmObsProb i@ of each arc's input label is added to its weight. The
-- result is in reverse arc-array order.
findNextHmmArcs :: WFST -> Int -> (Int -> Double) -> [Arc]
findNextHmmArcs wfst state hmmObsProb = U.foldl collect [] (outArcs wfst state)
  where
    collect acc arc
      | arcIn arc > 0 = withObs arc : acc -- input label > 0 is a HMM state
      | otherwise = acc
    withObs (p, n, i, o, w) = (p, n, i, o, w + hmmObsProb i)
    -- Disabled variant kept from the original code: also emit a simulated
    -- HMM self loop for arcs with w > 0.0, i.e.
    --   (p, p, i, 0, log (1.0 - exp (-w)) + hmmObsProb i)

-- | Read a WFST from a text file (see 'readWFST').
wfstFromFile :: String -> DcdConfig -> IO WFST
wfstFromFile fname conf = do
  buf <- L8.readFile fname
  return (readWFST buf conf)

-- | Parse a text-format WFST. Each line is one of:
--   @start end in out weight@ (an arc),
--   @start end in out@ (an arc with weight 0), or
--   @state@ (the final state).
-- Blank or whitespace-only lines are skipped. Any other shape, e.g. a final
-- state with a cost, is rejected with an error. The config argument is
-- currently unused.
readWFST :: L8.ByteString -> DcdConfig -> WFST
readWFST buf _conf = WFST stIndices arcs initSt finSt
  where
    rows = List.filter (not . null) $ List.map L8.words $ L8.splitWith (== '\n') buf

    -- Every parsed line; final-state lines become pseudo-arcs with end -1.
    arcArrayTmp = U.fromList (List.map line2arc rows)
    -- Real arcs only.
    arcs = U.filter (\x -> arcEnd x >= 0) arcArrayTmp

    line2arc [s, e, i, o, w] = (toInt s, toInt e, toInt i, toInt o, toDouble w)
    line2arc [s, e, i, o] = (toInt s, toInt e, toInt i, toInt o, 0.0)
    line2arc [s] = (toInt s, -1, 0, 0, 0.0)
    line2arc ws = error ("readWFST: unsupported line: " ++ unwords (List.map L8.unpack ws))

    -- An unparsable integer becomes -1. NOTE(review): an unparsable end
    -- state therefore makes the arc look like a final-state line.
    toInt bs = case L8.readInt bs of
      Just (a, _) -> a
      Nothing -> -1
    toDouble bs = read (L8.unpack bs) :: Double

    -- Index of arc start position and its numbers [(pos, num)] in arcArray.
    -- Ex. when arc start = [0,1,1,1,2,3,5,5] in arcArray, we want
    -- [(0,1), (1,3), (4,1), (5,1), (6,0), (6,2)]
    stIndices = U.fromList $ List.reverse $ List.zip revStartPoss revStartNums
    revStartNums = List.zipWith (-) (U.length arcs : revStartPoss) revStartPoss
    revStartPoss = firstOf3 (U.foldl' starts ([], 0, -1) arcs)
    -- Accumulator: (start positions, newest first; current arc index; last
    -- start state). A jump of k states pushes the current index k times, so
    -- the skipped states get empty ranges.
    starts (!acc, !idx, !lastSt) ac = case acStart - lastSt of
        0 -> (acc, idx + 1, lastSt)
        1 -> (idx : acc, idx + 1, acStart)
        gap
          | gap > 1 -> (List.replicate gap idx List.++ acc, idx + 1, acStart)
          | otherwise -> error "readWFST: arcs are not sorted by start state"
      where
        acStart = arcStart ac
    firstOf3 (a, _, _) = a

    -- Assumption 1: the initial state is the start state of the first line.
    initSt = arcStart (arcArrayTmp U.! 0)
    -- The single final state, or -1 if the file declares none.
    finSt = case U.find (\arc -> arcEnd arc == -1) arcArrayTmp of
      Just arc -> arcStart arc
      Nothing -> -1

-- Decoding ----------------------------------------------

-- | Make an empty Hypothesis Path Map (HPM): end state -> best path ending
-- there.
emptyHypPathMap :: Map.IntMap Path
emptyHypPathMap = Map.empty

-- | Insert a path and its score into the hpm. If a path with the same end
-- state already exists, the one with the smaller score is retained (the new
-- one on a tie).
insertHypPathMap :: Map.IntMap Path -> [Arc] -> Double -> Map.IntMap Path
insertHypPathMap hpm path score = Map.insertWith keepSmaller (pathEnd path) (path, score) hpm
  where
    -- insertWith calls this as (keepSmaller new old).
    keepSmaller new old = if snd new > snd old then old else new

-- | Order paths by score.
comparePathScore :: Path -> Path -> Ordering
comparePathScore (_, sc1) (_, sc2) = compare sc1 sc2

-- Depth first search of epsilon reachable nodes from the end of path.
-- Push path and paths derived from the path by epsilon transitions to hpm.
-- Returns the updated hpm.
--
-- A path whose score is strictly worse than the hypothesis already stored
-- for its end state is neither inserted nor extended. Every stored
-- hypothesis has its epsilon successors explored right after insertion, so
-- such a path and all its extensions would be discarded by insertHypPathMap
-- anyway. Skipping them gives the same map while avoiding repeated
-- enumeration of shared epsilon sub-graphs. It also lets epsilon cycles of
-- positive total weight terminate; cycles of zero or negative weight still
-- do not.
expandEps :: Map.IntMap Path -> WFST -> [Arc] -> Double -> Map.IntMap Path
expandEps hpm wfst path score = relax hpm path score
  where
    relax acc p sc
      | dominated acc (pathEnd p) sc = acc
      | otherwise =
          List.foldl' (extend p sc) (insertHypPathMap acc p sc)
                      (findNextEpsArcs wfst (pathEnd p))
    -- Follow one epsilon arc from the end of basePath.
    extend basePath baseScore acc arc =
      relax acc (basePath List.++ [arc]) (baseScore + arcWt arc)
    -- True if the hpm already holds a strictly better path ending in st.
    dominated acc st sc = case Map.lookup st acc of
      Just (_, best) -> sc > best
      Nothing -> False

-- | Extend the hypothesis that ends in @state@ with score @score@. It goes
-- through every HMM arc leaving the state, then through any following
-- epsilon arcs, and the results are merged into @hpm@. The arc weights from
-- findNextHmmArcs already include the observation cost.
expandState :: Map.IntMap Path -> WFST -> Int -> (Int -> Double) -> Double -> Map.IntMap Path
expandState hpm wfst state hmmObsProb score =
  List.foldl' step hpm (findNextHmmArcs wfst state hmmObsProb)
  where
    -- Each HMM arc starts a fresh partial path for the next frame.
    step acc arc = expandEps acc wfst [arc] (score + arcWt arc)

-- | Expand all active hypotheses of one frame into the next frame's
-- hypotheses, keeping one (the best) per reachable end state.
expandStates :: WFST -> [Path] -> (Int -> Double) -> [Path]
expandStates wfst activeArcSet hmmObsProb =
  Map.elems (List.foldl' step emptyHypPathMap activeArcSet)
  where
    step hpm (path, score) = expandState hpm wfst (pathEnd path) hmmObsProb score

-- Lattice is a list of (list of active hypotheses at each frame), newest
-- frame first.
-- Ex. [[([(0, 0, 0, 0, 0.0)], 0.0)]]
-- "a hypothesis at each frame" is a partial path corresponding to an isym x.
-- If there is an epsilon input symbol in WFST, the partial path has length
-- longer than 1.

-- | Frame-synchronous Viterbi beam search.
--
-- The search starts from the initial state, expanded through epsilon arcs.
-- For each feature frame, the previous frame's hypotheses are sorted by
-- score, cut to at most 'band' entries, and then cut to those within 'beam'
-- of the best. The survivors are extended through one HMM arc plus any
-- following epsilon arcs. The cost of HMM state @st@ is minus the
-- log-likelihood of the frame under pdf @st - 1@, memoised per frame.
--
-- Returns the lattice, newest frame first. If the active set becomes empty
-- while frames remain, the search is abandoned and @[]@ is returned.
-- 'backTrack' reports an empty lattice as a failure.
decodeWfst :: DcdConfig -> WFST -> V.Vector GMixture -> [U.Vector Double] -> [[Path]]
decodeWfst config wfst spdfs feaSeq = decodeWfstSub feaSeq initlat
  where
    -- Dummy self arc on the initial state; backTrack drops it again.
    initArc = (initState wfst, initState wfst, 0, 0, 0.0)
    initlat = [Map.elems $ expandEps emptyHypPathMap wfst [initArc] 0]

    decodeWfstSub [] lattice = lattice
    decodeWfstSub (fea : rest) lattice@(activeArcSet : _)
      | null activeArcSet = []
      | otherwise = decodeWfstSub rest (frame : lattice)
      where
        frame = expandStates wfst (prune activeArcSet) hmmObsProbCache
        hmmObsProbCache = Memo.integral hmmObsProb
        hmmObsProb :: Int -> Double
        hmmObsProb st = -(gmixLProb (spdfs V.! (st - 1)) fea)
    decodeWfstSub _ [] = []

    -- Band pruning, then beam pruning relative to the best hypothesis. An
    -- empty band (e.g. band <= 0) yields no survivors instead of a head [] crash.
    prune hyps = case List.take (band config) (List.sortBy compHyp hyps) of
      [] -> []
      inBand@((_, bestScore) : _) ->
        List.takeWhile (\(_, sc) -> sc - bestScore < beam config) inBand
    compHyp x y = compare (snd x) (snd y)

-- | Backtrack the Viterbi lattice (newest frame first) to find the best path
-- ending in @endState@, returned with its score. The dummy initial arc added
-- by 'decodeWfst' is dropped. Returns @([], 1e10)@ when there is no complete
-- path: either the lattice is empty (the search died) or some frame has no
-- hypothesis ending in the required state.
backTrack :: [[Path]] -> Int -> ([Arc], Double)
backTrack lattice endState = backTrackSub lattice endState [] 0
  where
    backTrackSub [] _ [] _ = ([], 1e10)
    backTrackSub [] _ (_ : arcs) score = (arcs, score)
    backTrackSub (frame : older) st pathTmp scoreTmp =
      case List.filter (\(p, _) -> pathEnd p == st) frame of
        [] -> ([], 1e10)
        candidates ->
          let (bestPartPath, bestPartScore) = List.minimumBy comparePathScore candidates
              -- The total score is the one found in the newest frame.
              score = if null pathTmp then bestPartScore else scoreTmp
          in backTrackSub older (pathStart bestPartPath) (bestPartPath List.++ pathTmp) score

-- State probability functions (log probability of x by a Gaussian mixture gmix)

-- | Log-likelihood of @x@ under a diagonal-covariance Gaussian mixture. The
-- value is @log (sum_k w_k * exp l_k)@, where
-- @l_k = -(d_k + gconst_k) / 2@ and @d_k@ is the variance-normalised
-- squared distance from @x@ to mean @k@. It is evaluated with the
-- log-sum-exp trick, so frames far from every component no longer underflow
-- to @log 0 = -Infinity@.
gmixLProb :: GMixture -> U.Vector Double -> Double
gmixLProb gmix x
  | U.null lcompPs = log 0
  | isInfinite maxL = maxL
  | otherwise =
      maxL + log (U.sum (U.zipWith (\l w -> exp (l - maxL) * w) lcompPs (weights gmix)))
  where
    sqDists = U.fromList
      [ U.sum (U.zipWith3 (\m v xi -> (m - xi) * (m - xi) / v) mu var x)
      | (mu, var) <- List.zip (means gmix) (vars gmix) ]
    lcompPs = U.zipWith (\d g -> -(d + g) / 2.0) sqDists (gconsts gmix)
    maxL = U.maximum lcompPs